TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

[[211015]]

現在、TensorFlow のメジャーバージョンは 1.3 にアップグレードされ、多くのネットワーク層のより高度なカプセル化と実装が実現され、さらに Keras などの優れた高レベルフレームワークも統合され、使いやすさが大幅に向上しました。初期の基盤コードと比較すると、今日の実装はより簡潔でエレガントになっています。

この記事は、中国語データセットでの TensorFlow の簡略化された実装です。文字レベルの CNN と RNN を使用して中国語のテキストを分類し、良好な結果を達成しています。

データセット

この記事では、Tsinghua NLP Group が提供する THUCNews ニューステキスト分類データセットのサブセットを使用します (元のデータセットには約 740,000 件のドキュメントが含まれており、トレーニングには長い時間がかかります)。 THUCTC: 効率的な中国語テキスト分類ツールキットからデータセットをダウンロードしてください。データ プロバイダーのオープン ソース契約に従ってください。

このトレーニングでは 10 個のカテゴリが使用され、各カテゴリには 6,500 個のデータ ポイントがありました。

カテゴリーは次のとおりです。

スポーツ、金融、不動産、住宅、教育、テクノロジー、ファッション、時事問題、ゲーム、エンターテイメント

このサブセットはここからダウンロードできます: リンク: http://pan.baidu.com/s/1bpq9Eub パスワード: ycyw

データセットは次のように分割されます。

  • トレーニングセット: 5000*10
  • 検証セット: 500*10
  • テストセット: 1000*10

元のデータセットからサブセットを生成するプロセスについては、ヘルパーの下の 2 つのスクリプトを参照してください。このうち、copy_data.sh は各カテゴリから 6500 個のファイルをコピーするために使用され、cnews_group.py は複数のファイルを 1 つのファイルに統合するために使用されています。ファイルを実行すると、次の 3 つのデータ ファイルが取得されます。

  • cnews.train.txt: トレーニング セット (50,000 項目)
  • cnews.val.txt: 検証セット (5000 項目)
  • cnews.test.txt: テストセット (10000 項目)

前処理

data/cnews_loader.py はデータ前処理ファイルです。

  • read_file(): ファイルデータを読み取ります。
  • build_vocab(): 文字レベルの表現を使用して語彙を構築します。この関数は語彙を保存し、毎回繰り返し処理されないようにしています。
  • read_vocab(): 前のステップで保存された語彙を読み取り、それを {word:id} 表現に変換します。
  • read_category(): カテゴリディレクトリを修正し、{category: id} 表現に変換します。
  • to_words(): id で表されるデータをテキストに変換します。
  • preocess_file(): データセットをテキストから固定長の ID シーケンス表現に変換します。
  • batch_iter(): ニューラル ネットワークのトレーニング用にシャッフルされたデータ バッチを準備します。

データの前処理後のデータ形式は次のようになります。

CNN 畳み込みニューラルネットワーク

構成項目

CNN の設定可能なパラメータは、以下の cnn_model.py に示されています。

  1. クラス TCNNConfig(オブジェクト):
  2. CNN 構成パラメータ  
  3.  
  4. embedding_dim = 64 # 単語ベクトルの次元
  5. seq_length = 600 # シーケンスの長さ
  6. num_classes = 10 # カテゴリの数
  7. num_filters = 128 # 畳み込みカーネルの数
  8. kernel_size = 5 # 畳み込みカーネルのサイズ
  9. vocab_size = 5000 # 小さな語彙
  10.  
  11. hidden_​​dim = 128 #完全結合層ニューロン
  12.  
  13. dropout_keep_prob = 0.5 # ドロップアウト保持率
  14. learning_rate = 1e-3 # 学習率
  15.  
  16. batch_size = 64 # バッチあたりのトレーニングサイズ
  17. num_epochs = 10 # 反復回数の合計
  18.  
  19. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  20. save_per_batch = 10 # テンソルボードに保存するラウンド数

CNNモデル

詳細については、cnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

トレーニングを開始するには、python run_cnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textcnn を削除してください。

  1. CNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 10.94%、値損失: 2.3、値精度: 8.92%、時間: 0:00:01 *
  8. 反復: 100、列車損失: 0.88、列車精度: 73.44%、値損失: 1.2、値精度: 68.46%、時間: 0:00:04 *
  9. 反復: 200、列車損失: 0.38、列車精度: 92.19%、値損失: 0.75、値精度: 77.32%、時間: 0:00:07 *
  10. 反復: 300、列車損失: 0.22、列車精度: 92.19%、値損失: 0.46、値精度: 87.08%、時間: 0:00:09 *
  11. 反復: 400、列車損失: 0.24、列車精度: 90.62%、値損失: 0.4、値精度: 88.62%、時間: 0:00:12 *
  12. 反復: 500、列車損失: 0.16、列車精度: 96.88%、値損失: 0.36、値精度: 90.38%、時間: 0:00:15 *
  13. 反復: 600、列車損失: 0.084、列車精度: 96.88%、値損失: 0.35、値精度: 91.36%、時間: 0:00:17 *
  14. 反復: 700、列車損失: 0.21、列車精度: 93.75%、値損失: 0.26、値精度: 92.58%、時間: 0:00:20 *
  15. エポック: 2
  16. 反復: 800、列車損失: 0.07、列車精度: 98.44%、値損失: 0.24、値精度: 94.12%、時間: 0:00:23 *
  17. 反復: 900、列車損失: 0.092、列車精度: 96.88%、値損失: 0.27、値精度: 92.86%、時間: 0:00:25
  18. 反復: 1000、列車損失: 0.17、列車精度: 95.31%、値損失: 0.28、値精度: 92.82%、時間: 0:00:28
  19. 反復: 1100、列車損失: 0.2、列車精度: 93.75%、値損失: 0.23、値精度: 93.26%、時間: 0:00:31
  20. 反復: 1200、列車損失: 0.081、列車精度: 98.44%、値損失: 0.25、値精度: 92.96%、時間: 0:00:33
  21. 反復: 1300、列車損失: 0.052、列車精度: 100.00%、値損失: 0.24、値精度: 93.58%、時間: 0:00:36
  22. 反復: 1400、列車損失: 0.1、列車精度: 95.31%、値損失: 0.22、値精度: 94.12%、時間: 0:00:39
  23. 反復: 1500、列車損失: 0.12、列車精度: 98.44%、値損失: 0.23、値精度: 93.58%、時間: 0:00:41
  24. エポック: 3
  25. 反復: 1600、列車損失: 0.1、列車精度: 96.88%、値損失: 0.26、値精度: 92.34%、時間: 0:00:44
  26. 反復: 1700、列車損失: 0.018、列車精度: 100.00%、値損失: 0.22、値精度: 93.46%、時間: 0:00:47
  27. 反復: 1800、列車損失: 0.036、列車精度: 100.00%、値損失: 0.28、値精度: 92.72%、時間: 0:00:50
  28. 長時間最適化されず自動停止します...

検証セットでの最良の結果は 94.12% で、アルゴリズムはわずか 3 回の反復後に停止しました。

精度と誤差は図に示されています。

テスト

テスト セットをテストするには、python run_cnn.py test を実行します。

  1. CNN モデルを構成しています...
  2. テストデータを読み込んでいます...
  3. テスト中...
  4. テスト損失: 0.14、テスト精度: 96.04%
  5. 精度、再現率 F1 スコア...
  6. 精度再現率 F1スコア サポート
  7.  
  8. スポーツ 0.99 0.99 0.99 1000
  9. 金融 0.96 0.99 0.97 1000
  10. 不動産 1.00 1.00 1.00 1000
  11. ホーム 0.95 0.91 0.93 1000
  12. 教育 0.95 0.89 0.92 1000
  13. テクノロジー 0.94 0.97 0.95 1000
  14. ファッション 0.95 0.97 0.96 1000
  15. 時事 0.94 0.94 0.94 1000
  16. ゲーム 0.97 0.96 0.97 1000
  17. エンターテイメント 0.95 0.98 0.97 1000
  18.  
  19. 平均/ 合計 0.96 0.96 0.96 10000
  20.  
  21. 混同マトリックス...
  22. [[991 0 0 0 2 1 0 4 1 1]
  23. [ 0 992 0 0 2 1 0 5 0 0 ]
  24. [ 0 1 996 0 1 1 0 0 0 1 ]
  25. [ 0 14 0 912 7 15 9 29 3 11 ]
  26. [ 2 9 0 12 892 22 18 21 10 14 ]
  27. [ 0 0 0 10 1 968 4 3 12 2 ]
  28. [ 1 0 0 9 4 4 971 0 2 9]
  29. [ 1 16 0 4 18 12 1 941 1 6 ]
  30. [ 2 4 1 5 4 5 10 1 962 6 ]
  31. [ 1 0 1 6 4 3 5 0 1 979]]
  32. 使用時間: 0:00:05

テスト セットの精度は 96.04% に達し、各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列からも分類効果が非常に優れていることがわかります。

RNN リカレント ニューラル ネットワーク

構成項目

RNN の設定可能なパラメータは、rnn_model.py に以下のように示されています。

  1. クラスTRNNConfig(オブジェクト):
  2. "" "RNN 構成パラメータ" ""  
  3.  
  4. # モデルパラメータ
  5. embedding_dim = 64 # 単語ベクトルの次元
  6. seq_length = 600 # シーケンスの長さ
  7. num_classes = 10 # カテゴリの数
  8. vocab_size = 5000 # 小さな語彙
  9.  
  10. num_layers = 2 # 隠し層の数
  11. hidden_​​dim = 128 # 隠れ層ニューロン
  12. rnn = 'gru' # lstm または gru
  13.  
  14. dropout_keep_prob = 0.8 # ドロップアウト保持率
  15. learning_rate = 1e-3 # 学習率
  16.  
  17. batch_size = 128 # バッチあたりのトレーニングサイズ
  18. num_epochs = 10 # 反復回数の合計
  19.  
  20. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  21. save_per_batch = 10 # テンソルボードに保存するラウンド数

RNN モデル

詳細については、rnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

この部分のコードは run_cnn.py と非常に似ていますが、モデルといくつかのディレクトリのみを少し変更する必要があります。

トレーニングを開始するには、python run_rnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textrnn を削除してください。

  1. RNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 8.59%、値損失: 2.3、値精度: 11.96%、時間: 0:00:08 *
  8. 反復: 100、列車損失: 0.95、列車精度: 64.06%、値損失: 1.3、値精度: 53.06%、時間: 0:01:15 *
  9. 反復: 200、列車損失: 0.61、列車精度: 79.69%、値損失: 0.94、値精度: 69.88%、時間: 0:02:22 *
  10. 反復: 300、列車損失: 0.49、列車精度: 85.16%、値損失: 0.63、値精度: 81.44%、時間: 0:03:29 *
  11. エポック: 2
  12. 反復: 400、列車損失: 0.23、列車精度: 92.97%、値損失: 0.6、値精度: 82.86%、時間: 0:04:36 *
  13. 反復: 500、列車損失: 0.27、列車精度: 92.97%、値損失: 0.47、値精度: 86.72%、時間: 0:05:43 *
  14. 反復: 600、列車損失: 0.13、列車精度: 98.44%、値損失: 0.43、値精度: 87.46%、時間: 0:06:50 *
  15. 反復: 700、列車損失: 0.24、列車精度: 91.41%、値損失: 0.46、値精度: 87.12%、時間: 0:07:57
  16. エポック: 3
  17. 反復: 800、列車損失: 0.11、列車精度: 96.09%、値損失: 0.49、値精度: 87.02%、時間: 0:09:03
  18. 反復: 900、列車損失: 0.15、列車精度: 96.09%、値損失: 0.55、値精度: 85.86%、時間: 0:10:10
  19. 反復: 1000、列車損失: 0.17、列車精度: 96.09%、値損失: 0.43、値精度: 89.44%、時間: 0:11:18 *
  20. 反復: 1100、列車損失: 0.25、列車精度: 93.75%、値損失: 0.42、値精度: 88.98%、時間: 0:12:25
  21. エポック: 4
  22. 反復: 1200、列車損失: 0.14、列車精度: 96.09%、値損失: 0.39、値精度: 89.82%、時間: 0:13:32 *
  23. 反復: 1300、列車損失: 0.2、列車精度: 96.09%、値損失: 0.43、値精度: 88.68%、時間: 0:14:38
  24. 反復: 1400、列車損失: 0.012、列車精度: 100.00%、値損失: 0.37、値精度: 90.58%、時間: 0:15:45 *
  25. 反復: 1500、列車損失: 0.15、列車精度: 96.88%、値損失: 0.39、値精度: 90.58%、時間: 0:16:52
  26. エポック: 5
  27. 反復: 1600、列車損失: 0.075、列車精度: 97.66%、値損失: 0.41、値精度: 89.90%、時間: 0:17:59
  28. 反復: 1700、列車損失: 0.042、列車精度: 98.44%、値損失: 0.41、値精度: 90.08%、時間: 0:19:06
  29. 反復: 1800、列車損失: 0.08、列車精度: 97.66%、値損失: 0.38、値精度: 91.36%、時間: 0:20:13 *
  30. 反復: 1900、列車損失: 0.089、列車精度: 98.44%、値損失: 0.39、値精度: 90.18%、時間: 0:21:20
  31. エポック: 6
  32. 反復: 2000、列車損失: 0.092、列車精度: 96.88%、値損失: 0.36、値精度: 91.42%、時間: 0:22:27 *
  33. 反復: 2100、列車損失: 0.062、列車精度: 98.44%、値損失: 0.39、値精度: 90.56%、時間: 0:23:34
  34. 反復: 2200、列車損失: 0.053、列車精度: 98.44%、値損失: 0.39、値精度: 90.02%、時間: 0:24:41
  35. 反復: 2300、列車損失: 0.12、列車精度: 96.09%、値損失: 0.37、値精度: 90.84%、時間: 0:25:48
  36. エポック: 7
  37. 反復: 2400、列車損失: 0.014、列車精度: 100.00%、値損失: 0.41、値精度: 90.38%、時間: 0:26:55
  38. 反復: 2500、列車損失: 0.14、列車精度: 96.88%、値損失: 0.37、値精度: 91.22%、時間: 0:28:01
  39. 反復: 2600、列車損失: 0.11、列車精度: 96.88%、値損失: 0.43、値精度: 89.76%、時間: 0:29:08
  40. 反復: 2700、列車損失: 0.089、列車精度: 97.66%、値損失: 0.37、値精度: 91.18%、時間: 0:30:15
  41. エポック: 8
  42. 反復: 2800、列車損失: 0.0081、列車精度: 100.00%、値損失: 0.44、値精度: 90.66%、時間: 0:31:22
  43. 反復: 2900、列車損失: 0.017、列車精度: 100.00%、値損失: 0.44、値精度: 89.62%、時間: 0:32:29
  44. 反復: 3000、列車損失: 0.061、列車精度: 96.88%、値損失: 0.43、値精度: 90.04%、時間: 0:33:36
  45. 長時間最適化されず自動停止します...

検証セットでの最高の結果は 91.42% で、8 ラウンドの反復後に停止しました。速度は CNN よりもはるかに遅いです。

精度と誤差は図に示されています。

テスト

python run_rnn.py test を実行して、テスト セットでテストを実行します。

  1. テスト中...
  2. テスト損失: 0.21、テスト精度: 94.22%
  3. 精度、再現率 F1 スコア...
  4. 精度再現率 F1スコア サポート
  5.  
  6. スポーツ 0.99 0.99 0.99 1000
  7. 金融 0.91 0.99 0.95 1000
  8. 不動産 1.00 1.00 1.00 1000
  9. ホーム 0.97 0.73 0.83 1000
  10. 教育 0.91 0.92 0.91 1000
  11. テクノロジー 0.93 0.96 0.94 1000
  12. ファッション 0.89 0.97 0.93 1000
  13. 時事 0.93 0.93 0.93 1000
  14. ゲーム 0.95 0.97 0.96 1000
  15. エンターテイメント 0.97 0.96 0.97 1000
  16.  
  17. 平均/ 合計 0.94 0.94 0.94 10000
  18.  
  19. 混同マトリックス...
  20. [[988 0 0 0 4 0 2 0 5 1]
  21. [ 0 9 9 0 1 1 1 1 0 6 0 0 ]
  22. [ 0 2 996 1 1 0 0 0 0 0 ]
  23. [ 2 71 1 731 51 20 88 28 3 5 ]
  24. [ 1 3 0 7 918 23 4 31 9 4 ]
  25. [ 1 3 0 3 0 964 3 5 21 0 ]
  26. [ 1 0 1 7 1 3 972 0 6 9]
  27. [ 0 16 0 0 22 26 0 931 2 3]
  28. [ 2 3 0 0 2 2 12 0 972 7 ]
  29. [ 0 3 1 1 7 3 11 5 9 960]]
  30. 使用時間: 0:00:33

テスト セットの精度は 94.22% に達し、ホーム カテゴリを除く各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列から、分類効果が非常に優れていることがわかります。

2 つのモデルを比較すると、家庭用家具の分類のパフォーマンスを除いて、他のカテゴリでの RNN のパフォーマンスは CNN とそれほど変わらないことがわかります。

パラメータをさらに調整することで、より良い結果を得ることもできます。

<<:  ディープラーニングを使用してNBAの試合結果を予測する

>>:  ウェブデザインに人工知能を活用する10の方法

ブログ    

推薦する

AIとRPA:両者の連携方法と、ビジネスに両方が必要な理由

ゴールドマン・サックスのレポートによると、AI は世界の労働生産性を年間 1% 以上向上させ、202...

...

IDC: 高速サーバー市場は2023年上半期に31億ドルに達し、GPUサーバーが依然として主流となる

10月9日、IDCコンサルティングの公式WeChatアカウントによると、IDCは本日「中国半期加速コ...

テスラ、マイクロソフト、グーグル、アップルなどを含む1,000件以上の「AIロールオーバー」事件が発生しています。

この記事はLeiphone.comから転載したものです。転載する場合は、Leiphone.com公式...

TensorFlow を通じてディープラーニング アルゴリズムを実装し、企業の実務に適用する方法

この記事は、Caiyun Technology のトップ ビッグ データ サイエンティストである Z...

エッジAI: ディープラーニングをより効率的にする方法

人工知能 (AI) は今日の産業情勢を変えています。 エンタープライズ ソフトウェアから機械の自動化...

AIが伝染病と闘う: 時折の恥ずかしさの裏に究極の防壁が現れる

人類と新型コロナウイルスとの戦いは今も続いていますが、この間、さまざまな「人工知能+」アプリケーショ...

AIシステムが初めて真の自律プログラミングを実現:遺伝的アルゴリズムを使用して初心者プログラマーを上回る

編集者注:この記事は、WeChatのパブリックアカウント「New Intelligence」(ID:...

...

...

分散システム設計のための負荷分散アルゴリズム

概要分散システムの設計では、通常、サービスはクラスターに展開されます。クラスター内の複数のノードが同...

...

生成AIは高価すぎるため、マイクロソフトやグーグルのような大手テクノロジー企業でさえも導入できない

テクノロジー企業は、AI がビジネスメモを書いたり、コンピューターコードを作成したりできると宣伝して...

ついに誰かがROSロボットオペレーティングシステムをわかりやすく説明しました

この記事はWeChatの公開アカウント「Big Data DT」から転載したもので、著者はZhang...