TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

[[211015]]

現在、TensorFlow のメジャーバージョンは 1.3 にアップグレードされ、多くのネットワーク層のより高度なカプセル化と実装が実現され、さらに Keras などの優れた高レベルフレームワークも統合され、使いやすさが大幅に向上しました。初期の基盤コードと比較すると、今日の実装はより簡潔でエレガントになっています。

この記事は、中国語データセットでの TensorFlow の簡略化された実装です。文字レベルの CNN と RNN を使用して中国語のテキストを分類し、良好な結果を達成しています。

データセット

この記事では、Tsinghua NLP Group が提供する THUCNews ニューステキスト分類データセットのサブセットを使用します (元のデータセットには約 740,000 件のドキュメントが含まれており、トレーニングには長い時間がかかります)。 THUCTC: 効率的な中国語テキスト分類ツールキットからデータセットをダウンロードしてください。データ プロバイダーのオープン ソース契約に従ってください。

このトレーニングでは 10 個のカテゴリが使用され、各カテゴリには 6,500 個のデータ ポイントがありました。

カテゴリーは次のとおりです。

スポーツ、金融、不動産、住宅、教育、テクノロジー、ファッション、時事問題、ゲーム、エンターテイメント

このサブセットはここからダウンロードできます: リンク: http://pan.baidu.com/s/1bpq9Eub パスワード: ycyw

データセットは次のように分割されます。

  • トレーニングセット: 5000*10
  • 検証セット: 500*10
  • テストセット: 1000*10

元のデータセットからサブセットを生成するプロセスについては、ヘルパーの下の 2 つのスクリプトを参照してください。このうち、copy_data.sh は各カテゴリから 6500 個のファイルをコピーするために使用され、cnews_group.py は複数のファイルを 1 つのファイルに統合するために使用されています。ファイルを実行すると、次の 3 つのデータ ファイルが取得されます。

  • cnews.train.txt: トレーニング セット (50,000 項目)
  • cnews.val.txt: 検証セット (5000 項目)
  • cnews.test.txt: テストセット (10000 項目)

前処理

data/cnews_loader.py はデータ前処理ファイルです。

  • read_file(): ファイルデータを読み取ります。
  • build_vocab(): 文字レベルの表現を使用して語彙を構築します。この関数は語彙を保存し、毎回繰り返し処理されないようにしています。
  • read_vocab(): 前のステップで保存された語彙を読み取り、それを {word:id} 表現に変換します。
  • read_category(): カテゴリディレクトリを修正し、{category: id} 表現に変換します。
  • to_words(): id で表されるデータをテキストに変換します。
  • preocess_file(): データセットをテキストから固定長の ID シーケンス表現に変換します。
  • batch_iter(): ニューラル ネットワークのトレーニング用にシャッフルされたデータ バッチを準備します。

データの前処理後のデータ形式は次のようになります。

CNN 畳み込みニューラルネットワーク

構成項目

CNN の設定可能なパラメータは、以下の cnn_model.py に示されています。

  1. クラス TCNNConfig(オブジェクト):
  2. CNN 構成パラメータ  
  3.  
  4. embedding_dim = 64 # 単語ベクトルの次元
  5. seq_length = 600 # シーケンスの長さ
  6. num_classes = 10 # カテゴリの数
  7. num_filters = 128 # 畳み込みカーネルの数
  8. kernel_size = 5 # 畳み込みカーネルのサイズ
  9. vocab_size = 5000 # 小さな語彙
  10.  
  11. hidden_​​dim = 128 #完全結合層ニューロン
  12.  
  13. dropout_keep_prob = 0.5 # ドロップアウト保持率
  14. learning_rate = 1e-3 # 学習率
  15.  
  16. batch_size = 64 # バッチあたりのトレーニングサイズ
  17. num_epochs = 10 # 反復回数の合計
  18.  
  19. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  20. save_per_batch = 10 # テンソルボードに保存するラウンド数

CNNモデル

詳細については、cnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

トレーニングを開始するには、python run_cnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textcnn を削除してください。

  1. CNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 10.94%、値損失: 2.3、値精度: 8.92%、時間: 0:00:01 *
  8. 反復: 100、列車損失: 0.88、列車精度: 73.44%、値損失: 1.2、値精度: 68.46%、時間: 0:00:04 *
  9. 反復: 200、列車損失: 0.38、列車精度: 92.19%、値損失: 0.75、値精度: 77.32%、時間: 0:00:07 *
  10. 反復: 300、列車損失: 0.22、列車精度: 92.19%、値損失: 0.46、値精度: 87.08%、時間: 0:00:09 *
  11. 反復: 400、列車損失: 0.24、列車精度: 90.62%、値損失: 0.4、値精度: 88.62%、時間: 0:00:12 *
  12. 反復: 500、列車損失: 0.16、列車精度: 96.88%、値損失: 0.36、値精度: 90.38%、時間: 0:00:15 *
  13. 反復: 600、列車損失: 0.084、列車精度: 96.88%、値損失: 0.35、値精度: 91.36%、時間: 0:00:17 *
  14. 反復: 700、列車損失: 0.21、列車精度: 93.75%、値損失: 0.26、値精度: 92.58%、時間: 0:00:20 *
  15. エポック: 2
  16. 反復: 800、列車損失: 0.07、列車精度: 98.44%、値損失: 0.24、値精度: 94.12%、時間: 0:00:23 *
  17. 反復: 900、列車損失: 0.092、列車精度: 96.88%、値損失: 0.27、値精度: 92.86%、時間: 0:00:25
  18. 反復: 1000、列車損失: 0.17、列車精度: 95.31%、値損失: 0.28、値精度: 92.82%、時間: 0:00:28
  19. 反復: 1100、列車損失: 0.2、列車精度: 93.75%、値損失: 0.23、値精度: 93.26%、時間: 0:00:31
  20. 反復: 1200、列車損失: 0.081、列車精度: 98.44%、値損失: 0.25、値精度: 92.96%、時間: 0:00:33
  21. 反復: 1300、列車損失: 0.052、列車精度: 100.00%、値損失: 0.24、値精度: 93.58%、時間: 0:00:36
  22. 反復: 1400、列車損失: 0.1、列車精度: 95.31%、値損失: 0.22、値精度: 94.12%、時間: 0:00:39
  23. 反復: 1500、列車損失: 0.12、列車精度: 98.44%、値損失: 0.23、値精度: 93.58%、時間: 0:00:41
  24. エポック: 3
  25. 反復: 1600、列車損失: 0.1、列車精度: 96.88%、値損失: 0.26、値精度: 92.34%、時間: 0:00:44
  26. 反復: 1700、列車損失: 0.018、列車精度: 100.00%、値損失: 0.22、値精度: 93.46%、時間: 0:00:47
  27. 反復: 1800、列車損失: 0.036、列車精度: 100.00%、値損失: 0.28、値精度: 92.72%、時間: 0:00:50
  28. 長時間最適化されず自動停止します...

検証セットでの最良の結果は 94.12% で、アルゴリズムはわずか 3 回の反復後に停止しました。

精度と誤差は図に示されています。

テスト

テスト セットをテストするには、python run_cnn.py test を実行します。

  1. CNN モデルを構成しています...
  2. テストデータを読み込んでいます...
  3. テスト中...
  4. テスト損失: 0.14、テスト精度: 96.04%
  5. 精度、再現率 F1 スコア...
  6. 精度再現率 F1スコア サポート
  7.  
  8. スポーツ 0.99 0.99 0.99 1000
  9. 金融 0.96 0.99 0.97 1000
  10. 不動産 1.00 1.00 1.00 1000
  11. ホーム 0.95 0.91 0.93 1000
  12. 教育 0.95 0.89 0.92 1000
  13. テクノロジー 0.94 0.97 0.95 1000
  14. ファッション 0.95 0.97 0.96 1000
  15. 時事 0.94 0.94 0.94 1000
  16. ゲーム 0.97 0.96 0.97 1000
  17. エンターテイメント 0.95 0.98 0.97 1000
  18.  
  19. 平均/ 合計 0.96 0.96 0.96 10000
  20.  
  21. 混同マトリックス...
  22. [[991 0 0 0 2 1 0 4 1 1]
  23. [ 0 992 0 0 2 1 0 5 0 0 ]
  24. [ 0 1 996 0 1 1 0 0 0 1 ]
  25. [ 0 14 0 912 7 15 9 29 3 11 ]
  26. [ 2 9 0 12 892 22 18 21 10 14 ]
  27. [ 0 0 0 10 1 968 4 3 12 2 ]
  28. [ 1 0 0 9 4 4 971 0 2 9]
  29. [ 1 16 0 4 18 12 1 941 1 6 ]
  30. [ 2 4 1 5 4 5 10 1 962 6 ]
  31. [ 1 0 1 6 4 3 5 0 1 979]]
  32. 使用時間: 0:00:05

テスト セットの精度は 96.04% に達し、各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列からも分類効果が非常に優れていることがわかります。

RNN リカレント ニューラル ネットワーク

構成項目

RNN の設定可能なパラメータは、rnn_model.py に以下のように示されています。

  1. クラスTRNNConfig(オブジェクト):
  2. "" "RNN 構成パラメータ" ""  
  3.  
  4. # モデルパラメータ
  5. embedding_dim = 64 # 単語ベクトルの次元
  6. seq_length = 600 # シーケンスの長さ
  7. num_classes = 10 # カテゴリの数
  8. vocab_size = 5000 # 小さな語彙
  9.  
  10. num_layers = 2 # 隠し層の数
  11. hidden_​​dim = 128 # 隠れ層ニューロン
  12. rnn = 'gru' # lstm または gru
  13.  
  14. dropout_keep_prob = 0.8 # ドロップアウト保持率
  15. learning_rate = 1e-3 # 学習率
  16.  
  17. batch_size = 128 # バッチあたりのトレーニングサイズ
  18. num_epochs = 10 # 反復回数の合計
  19.  
  20. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  21. save_per_batch = 10 # テンソルボードに保存するラウンド数

RNN モデル

詳細については、rnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

この部分のコードは run_cnn.py と非常に似ていますが、モデルといくつかのディレクトリのみを少し変更する必要があります。

トレーニングを開始するには、python run_rnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textrnn を削除してください。

  1. RNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 8.59%、値損失: 2.3、値精度: 11.96%、時間: 0:00:08 *
  8. 反復: 100、列車損失: 0.95、列車精度: 64.06%、値損失: 1.3、値精度: 53.06%、時間: 0:01:15 *
  9. 反復: 200、列車損失: 0.61、列車精度: 79.69%、値損失: 0.94、値精度: 69.88%、時間: 0:02:22 *
  10. 反復: 300、列車損失: 0.49、列車精度: 85.16%、値損失: 0.63、値精度: 81.44%、時間: 0:03:29 *
  11. エポック: 2
  12. 反復: 400、列車損失: 0.23、列車精度: 92.97%、値損失: 0.6、値精度: 82.86%、時間: 0:04:36 *
  13. 反復: 500、列車損失: 0.27、列車精度: 92.97%、値損失: 0.47、値精度: 86.72%、時間: 0:05:43 *
  14. 反復: 600、列車損失: 0.13、列車精度: 98.44%、値損失: 0.43、値精度: 87.46%、時間: 0:06:50 *
  15. 反復: 700、列車損失: 0.24、列車精度: 91.41%、値損失: 0.46、値精度: 87.12%、時間: 0:07:57
  16. エポック: 3
  17. 反復: 800、列車損失: 0.11、列車精度: 96.09%、値損失: 0.49、値精度: 87.02%、時間: 0:09:03
  18. 反復: 900、列車損失: 0.15、列車精度: 96.09%、値損失: 0.55、値精度: 85.86%、時間: 0:10:10
  19. 反復: 1000、列車損失: 0.17、列車精度: 96.09%、値損失: 0.43、値精度: 89.44%、時間: 0:11:18 *
  20. 反復: 1100、列車損失: 0.25、列車精度: 93.75%、値損失: 0.42、値精度: 88.98%、時間: 0:12:25
  21. エポック: 4
  22. 反復: 1200、列車損失: 0.14、列車精度: 96.09%、値損失: 0.39、値精度: 89.82%、時間: 0:13:32 *
  23. 反復: 1300、列車損失: 0.2、列車精度: 96.09%、値損失: 0.43、値精度: 88.68%、時間: 0:14:38
  24. 反復: 1400、列車損失: 0.012、列車精度: 100.00%、値損失: 0.37、値精度: 90.58%、時間: 0:15:45 *
  25. 反復: 1500、列車損失: 0.15、列車精度: 96.88%、値損失: 0.39、値精度: 90.58%、時間: 0:16:52
  26. エポック: 5
  27. 反復: 1600、列車損失: 0.075、列車精度: 97.66%、値損失: 0.41、値精度: 89.90%、時間: 0:17:59
  28. 反復: 1700、列車損失: 0.042、列車精度: 98.44%、値損失: 0.41、値精度: 90.08%、時間: 0:19:06
  29. 反復: 1800、列車損失: 0.08、列車精度: 97.66%、値損失: 0.38、値精度: 91.36%、時間: 0:20:13 *
  30. 反復: 1900、列車損失: 0.089、列車精度: 98.44%、値損失: 0.39、値精度: 90.18%、時間: 0:21:20
  31. エポック: 6
  32. 反復: 2000、列車損失: 0.092、列車精度: 96.88%、値損失: 0.36、値精度: 91.42%、時間: 0:22:27 *
  33. 反復: 2100、列車損失: 0.062、列車精度: 98.44%、値損失: 0.39、値精度: 90.56%、時間: 0:23:34
  34. 反復: 2200、列車損失: 0.053、列車精度: 98.44%、値損失: 0.39、値精度: 90.02%、時間: 0:24:41
  35. 反復: 2300、列車損失: 0.12、列車精度: 96.09%、値損失: 0.37、値精度: 90.84%、時間: 0:25:48
  36. エポック: 7
  37. 反復: 2400、列車損失: 0.014、列車精度: 100.00%、値損失: 0.41、値精度: 90.38%、時間: 0:26:55
  38. 反復: 2500、列車損失: 0.14、列車精度: 96.88%、値損失: 0.37、値精度: 91.22%、時間: 0:28:01
  39. 反復: 2600、列車損失: 0.11、列車精度: 96.88%、値損失: 0.43、値精度: 89.76%、時間: 0:29:08
  40. 反復: 2700、列車損失: 0.089、列車精度: 97.66%、値損失: 0.37、値精度: 91.18%、時間: 0:30:15
  41. エポック: 8
  42. 反復: 2800、列車損失: 0.0081、列車精度: 100.00%、値損失: 0.44、値精度: 90.66%、時間: 0:31:22
  43. 反復: 2900、列車損失: 0.017、列車精度: 100.00%、値損失: 0.44、値精度: 89.62%、時間: 0:32:29
  44. 反復: 3000、列車損失: 0.061、列車精度: 96.88%、値損失: 0.43、値精度: 90.04%、時間: 0:33:36
  45. 長時間最適化されず自動停止します...

検証セットでの最高の結果は 91.42% で、8 ラウンドの反復後に停止しました。速度は CNN よりもはるかに遅いです。

精度と誤差は図に示されています。

テスト

python run_rnn.py test を実行して、テスト セットでテストを実行します。

  1. テスト中...
  2. テスト損失: 0.21、テスト精度: 94.22%
  3. 精度、再現率 F1 スコア...
  4. 精度再現率 F1スコア サポート
  5.  
  6. スポーツ 0.99 0.99 0.99 1000
  7. 金融 0.91 0.99 0.95 1000
  8. 不動産 1.00 1.00 1.00 1000
  9. ホーム 0.97 0.73 0.83 1000
  10. 教育 0.91 0.92 0.91 1000
  11. テクノロジー 0.93 0.96 0.94 1000
  12. ファッション 0.89 0.97 0.93 1000
  13. 時事 0.93 0.93 0.93 1000
  14. ゲーム 0.95 0.97 0.96 1000
  15. エンターテイメント 0.97 0.96 0.97 1000
  16.  
  17. 平均/ 合計 0.94 0.94 0.94 10000
  18.  
  19. 混同マトリックス...
  20. [[988 0 0 0 4 0 2 0 5 1]
  21. [ 0 9 9 0 1 1 1 1 0 6 0 0 ]
  22. [ 0 2 996 1 1 0 0 0 0 0 ]
  23. [ 2 71 1 731 51 20 88 28 3 5 ]
  24. [ 1 3 0 7 918 23 4 31 9 4 ]
  25. [ 1 3 0 3 0 964 3 5 21 0 ]
  26. [ 1 0 1 7 1 3 972 0 6 9]
  27. [ 0 16 0 0 22 26 0 931 2 3]
  28. [ 2 3 0 0 2 2 12 0 972 7 ]
  29. [ 0 3 1 1 7 3 11 5 9 960]]
  30. 使用時間: 0:00:33

テスト セットの精度は 94.22% に達し、ホーム カテゴリを除く各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列から、分類効果が非常に優れていることがわかります。

2 つのモデルを比較すると、家庭用家具の分類のパフォーマンスを除いて、他のカテゴリでの RNN のパフォーマンスは CNN とそれほど変わらないことがわかります。

パラメータをさらに調整することで、より良い結果を得ることもできます。

<<:  ディープラーニングを使用してNBAの試合結果を予測する

>>:  ウェブデザインに人工知能を活用する10の方法

ブログ    
ブログ    
ブログ    
ブログ    

推薦する

IBM、AI導入を加速しAIの透明性を向上するオープンプラットフォームを発表

[[247168]]最近、IBM は、AI アプリケーションがどのように意思決定を行うかを説明する際...

大規模言語モデルの量子化手法の比較: GPTQ、GGUF、AWQ

大規模言語モデル (LLM) は過去 1 年間で急速に進化しており、この記事では (量子化) へのい...

...

AI とデジタル病理学は医療通信をどのように改善できるのでしょうか?

人工知能 (AI) とデジタル病理学は、特に通信分野において医療業界に革命をもたらすと期待されていま...

...

AIが医療をどう変えるか リアルタイムのデータ分析は医療にとって重要

科学者たちは、人工知能が多くの分野で人間を日常的な作業から解放できると信じています。ヘルスケアはこう...

集団雷雨!自動化された攻撃により、主要な言語モデルを1分で脱獄できる

大規模な言語モデル アプリケーションが直面する 2 つの主要なセキュリティ上の脅威は、トレーニング ...

...

顔スキャン決済は問題多し、アマゾンは「手のひら」スキャンを選択し無人スーパーで正式に商品化

さあ、手払いについて学んでみましょう〜アマゾンはこのほど、自社が開発した手のひら認識技術「Amazo...

モビリティの未来:スマート、持続可能、効率的

[[348989]] COVID-19のロックダウンの緩和により多くの社会的要因が浮き彫りになりまし...

人工知能がファッションデザインと生産を変革

人工知能とロボット工学がファッション業界に変化をもたらしています。市場分析からカスタムデザイン、無駄...

...

VB.NET バブルソートアルゴリズムの詳細な説明

VB.NET を学習する場合、中国語の情報が非常に少なく、大多数のプログラマーのニーズを満たすのが難...

人工知能は裁判所によって特許発明者とみなされるでしょうか?

人工知能(AI)は、新薬の発見から新しい数学の問題の解決まで、あらゆることを人間が行うのに役立ってお...

チャットボットのさまざまな種類について学ぶ

チャットボットの種類は、提供されるさまざまな機能と応答に使用する方法によって決まります。チャットボッ...