この記事の主な内容は、TensorFlow で RNN のいくつかの構造を実装する方法です。
1. シングルステップRNNの学習: RNNCell TensorFlow の RNN について学習したい場合、まず最初に「RNNCell」について学習する必要があります。これは、TensorFlow で RNN を実装するための基本単位です。各 RNNCell には call メソッドがあり、次のように使用されます: (output, next_state) = call(input, state)。 写真があればもっと分かりやすいかもしれません。初期状態 h0 と入力 x1 があり、call(x1, h0) を呼び出すと (output1, h1) が得られるとします。 もう一度call(x2, h1)を呼び出すと、(output2, h2)が返されます。 つまり、RNNCell 呼び出しメソッドの各呼び出しは、時間を「 1 ステップ進める」ことに相当します。これが RNNCell の基本的な機能です。 コード実装では、RNNCell は単なる抽象クラスです。これを使用する場合は、その 2 つのサブクラスである BasicRNNCell と BasicLSTMCell を使用します。名前が示すように、前者は RNN の基本クラスであり、後者は LSTM の基本クラスです。ここでは、ソースコードの実装を読むことをお勧めします。最初からすべてを読む必要はありません。RNNCell、BasicRNNCell、BasicLSTMCell の 3 つのクラスのコメントを見るだけで、それらの機能を理解できるはずです。 call メソッドに加えて、RNNCell にはさらに 2 つの重要なクラス属性があります。
前者は隠れ層のサイズであり、後者は出力のサイズです。たとえば、通常は計算のためにバッチをモデルに送信します。入力データの形状が (batch_size、input_size) であると仮定すると、計算中に取得される隠し層の状態は (batch_size、state_size) となり、出力は (batch_size、output_size) となります。 次のコードで確認できます (次のコードは TensorFlow*** のバージョン 1.2 に基づいていることに注意してください)。
BasicLSTMCell の場合、状況は少し異なります。LSTM には 2 つの隠し状態 h と c があり、対応する隠し層はタプルであり、それぞれが (batch_size、state_size) の形状であるためです。
2. 一度に複数のステップを実行する方法を学ぶ: tf.nn.dynamic_rnn 基本的な RNNCell には明らかな問題があります。単一の RNNCell の場合、その呼び出し関数を使用して操作を実行すると、シーケンス時間で 1 ステップしか前進しません。たとえば、x1 と h0 を使用して h1 を取得し、x2 と h1 を使用して h2 を取得します。この場合、シーケンスの長さが 10 であれば、呼び出し関数を 10 回呼び出す必要があり、面倒です。この目的のために、TensorFlow は tf.nn.dynamic_rnn 関数を提供します。これは、呼び出し関数を n 回呼び出すことと同等です。つまり、{h1,h2…,hn}は{h0,x1,x2,….,xn}を通じて直接取得できます。 具体的には、入力データの形式が (batch_size、time_steps、input_size) であるとします。ここで、time_steps はシーケンス自体の長さを表します。たとえば、Char RNN では、長さ 10 の文は、time_steps が 10 に相当します。まず、input_size は、単一の時間次元における入力データの単一シーケンスの固有の長さを表します。さらに、RNNCell を定義し、RNNCell の呼び出し関数を time_steps 回呼び出しました。対応するコードは次のとおりです。
この時点で得られる出力は、time_steps ステップのすべての出力です。その形状は (batch_size、time_steps、cell.output_size) です。 state は最後のステップの隠し状態であり、その形状は (batch_size、cell.state_size) です。 さらに理解を深めるには、tf.nn.dynamic_rnn のドキュメントを読むことをお勧めします。 3. RNNCell をスタックする方法を学ぶ: MultiRNNCell 多くの場合、単層 RNN の機能には制限があり、多層 RNN が必要になります。 x を RNN の最初のレイヤーに入力すると、隠れ状態 h が得られます。この隠れ状態は、RNN の 2 番目のレイヤーの入力に相当します。RNN の 2 番目のレイヤーの隠れ状態は、RNN の 3 番目のレイヤーの入力に相当します。 TensorFlow では、tf.nn.rnn_cell.MultiRNNCell 関数を使用して RNNCell をスタックできます。対応するサンプル プログラムは次のとおりです。
MultiRNNCell を通じて取得されるセルは新しいものではありません。実際には RNNCell のサブクラスなので、call メソッド、state_size および output_size 属性も持っています。 tf.nn.dynamic_rnn を使用して一度に複数のステップを実行することも可能です。 MutiRNNCell の機能をさらに理解するには、MutiRNNCell ソース コード内のコメントを読むことをお勧めします。 4. 潜在的な落とし穴1: 出力の説明 古典的な RNN 構造には、次のような図があります。 上記のコードでは、call または dynamic_rnn 関数を呼び出した後に取得される出力の導入を意図的に無視しているようです。上の図を TensorFlow の BasicRNNCell と比較してください。 h は BasicRNNCell の state_size に対応します。では、y は BasicRNNCell の output_size に対応しているのでしょうか? 答えは「いいえ」です。 ソース コードで BasicRNNCell の呼び出し関数の実装を見つけます。
「return output, output」という文は、BasicRNNCell では出力が実際には隠し状態の値と同じであることを示しています。したがって、図の実際の出力 y を取得するには、出力に対して追加の変換を定義する必要があります。出力と隠し状態は同じものなので、BasicRNNCell では state_size は常に output_size と等しくなります。 TensorFlow は、できるだけ簡潔にするために BasicRNNCell を定義しているため、出力パラメータは省略されています。ここでは、図の元の RNN 定義とどのように関連し、どのように異なるかを明確にする必要があります。 BasicLSTMCell の呼び出し関数定義 (関数の最初の数行) を見てみましょう。
self._state_is_tuple == True の場合のみに注意する必要があります。 self._state_is_tuple == False の場合は将来非推奨になる予定です。返される隠し状態は new_c と new_h の組み合わせであり、出力は new_h のみになります。分類問題を扱っている場合は、最適な分類確率出力を得るために、new_h に別の Softmax レイヤーを追加する必要もあります。 詳細を理解するには、ソース コードの実装を自分で確認することをお勧めします。 5. 潜在的な落とし穴2: バージョンの問題によるエラー 先ほど RNN のスタッキングについて説明したとき、使用したコードは次のとおりです。
このコードは TensorFlow 1.2 では正常に動作します。しかし、以前のバージョン (およびインターネット上の多くの関連チュートリアル) では、実装は次のようになります。
TensorFlow 1.2 で元の方法で定義すると、エラーが発生します。 6. 実践プロジェクト: Char RNN 上記の内容は、実は TensorFlow で RNN を実装するための基礎知識です。このとき、プロジェクトを使用して練習し、定着させることをお勧めします。ここでは、Char RNN プロジェクトが特にお勧めです。このプロジェクトは、古典的な RNN 構造に対応しています。これを実装するために使用される TensorFlow 関数は、上記のものです。プロジェクト自体は非常に興味深く、テキスト生成に使用できます。これは基本的に、ディープラーニングを使用して詩や歌詞を書くときによく見られるものです。 Char RNN の実装はすでにたくさんあり、Github で見つけることができます。参考までに、ここで実装も作成しました。プロジェクトのアドレスは、hzy46/Char-RNN-TensorFlow です。 中国語をサポートするために、主にコードに埋め込みレイヤーを追加しました。また、コード構造を再編成し、API を最新の TensorFlow 1.2 バージョンに変更しました。 このプロジェクトを使用して詩を書くことができます (次の詩は自動的に生成されます)。 誰でもこの場所がどのような場所か見ることができます。 一夜にして山へ行き、一夜にして山川へ帰る。 山のそよ風が春の草を緑に染め、秋の水が夜に深い音を立てます。 なぜ私たちは会うのでしょうか?それは私たちが古くからの友人だからです。 どうして私たちは二度と会えないのだろう? どうして川辺で会えるのだろう? 雲の中に木の葉が生え、竹の堂からは春風が吹き抜けます。 私があなたを訪ねても、それはあなたの心の中にはないでしょう。 コードを生成することもできます:
英語の生成は問題ありません(シェイクスピアのテキストを使用してトレーニング済み): 発売: 彼には形があまりにも歪んでいた、あなたは彼女が 彼女に聞くと、私たちは何を言うだろう 主とすべての反則とあまりにも、言うだろう、 私たちはここに平和と共にいます。 パリナ: なぜ、あなたは呼吸したり敬礼したりしなければならないのですか? 私は彼を満足させすぎた 私はキャンプルスです。 ***、想像力が豊かなら、もっと面白いこともできます。たとえば、有名なオンライン小説「Battle Through the Heavens」を使用して RNN モデルをトレーニングすると、次のテキストを生成できます。 これを聞いて、シャオ・ヤンはびっくりして、隣にいる灰色のローブを着た青年に視線を向け、それから老人に視線を向けた。そこには、巨大な石の台座の上に巨大な穴があり、そこからいくつかの黒い光の柱が出ていた。空からは巨大な黒いニシキヘビと非常に恐ろしいオーラが噴き出していた。そして、何人かの目には、稲妻のようにそれらの姿が現れた。その魂には、多くの強い人々の感覚があった。彼らの前には、それらの姿は黒い影のようだった。その目の中で、この巨大な空間に、彼らは広がっていた... 「ここは闘尊レベルだが、何をしても動けない。あいつらはここのためにこんなことができる。ここには何か異常があるかもしれないし、他人の魂を渡すこともできない。だから、この強者を呑み込む天蛇に渡すわけにはいかない。今度は、俺たちの力で倒せる……」 「ここの人たちも、魂宮の強者たちと戦えるほどだ」 シャオ・ヤンの目には一瞬の恐怖が浮かび、それから彼は微笑み、そして冷たい叫び声を上げた。彼の後ろにいた魂宮の達人たちはシャオ・ヤンに向かって叫んだ。空から体が飛び出し、恐ろしいエネルギーが空から降り注いだ。 "笑う!" なかなか楽しいので、日本語生成などもやってみました。 7. LSTMCellの完全版を学ぶ 上記では、BasicRNNCell と BasicLSTMCell の基本バージョンのみを説明しました。 TensorFlow には「完全な」 LSTM である LSTMCell もあります。この LSTM の完全バージョンでは、ピープホールを定義し、出力投影層を追加し、LSTM 忘却ユニットのバイアスを設定できます。使用方法については、ソース コードを参照してください。 8. 最新のSeq2Seq APIを学ぶ Google は、TensorFlow バージョン 1.2 で Seq2Seq API を更新しました (1.3.0 rc バージョンがリリースされ、正式バージョンも間もなくリリースされるようです。更新が非常に速いです)。この API を使用すると、Seq2Seq モデルで Encoder と Decoder を手動で定義する必要がなくなります。さらに、バージョン 1.2 の新しいデータ読み取り方法である Datasets にも対応しています。使用方法については、こちらのドキュメントをお読みください。 IX. 結論 ***簡単にまとめると、この記事では、学習の順序、起こりうる落とし穴、ソースコード分析、サンプルプロジェクト hzy46/Char-RNN-TensorFlow など、TensorFlow RNN 実装を学習するための詳細なパスを提供します。皆様のお役に立てれば幸いです。 |
>>: TensorFlow の基礎から実践まで: 交通標識分類ニューラル ネットワークの作成方法を段階的に学習します
人工知能 (AI) とデジタル病理学は、特に通信分野において医療業界に革命をもたらすと期待されていま...
2020年は自動運転業界が徐々に安定する年だ。ウェイモなどの巨大企業が商業化の模索を開始し、テスラ...
サイエンス フィクションや大衆文化では、人工知能 (AI) 技術に関する大胆な予測や説明がよく取り上...
[[342573]]研究室の菌類1928 年、スコットランドの研究者アレクサンダー・フレミングが休暇...
AdobeやCelsysなどのソフトウェア企業は近年、デジタルデザインソフトウェアに人工知能機能を追...
音声とテキストの両方における自然言語処理 (NLP) の改善は、主流のテクノロジーの進歩に役立ちます...
最近、P2Pプラットフォームが頻繁に崩壊していることから、インターネット金融プラットフォームの長期的...
[[440742]]この記事はLeiphone.comから転載したものです。転載する場合は、Leip...
エッジデバイスとコンピューティングにおける AI アプリケーションが未来である理由は何でしょうか?変...
シンプルな Java 暗号化アルゴリズムは次のとおりです。厳密に言えば、BASE64 は暗号化アルゴ...
Chris Betz 氏は、サイバーセキュリティにおける GenAI の役割について恐れたり、過度に...