少数ショット学習における SetFit によるテキスト分類

少数ショット学習における SetFit によるテキスト分類

翻訳者 |陳俊

レビュー | Chonglou

この記事では、「少量学習」の概念を紹介しテキスト分類で広く使用されているSetFit法に焦点を当てます。

従来の機械学習 (ML)

教師あり機械学習では、正確な予測を行う能力を磨くために、大規模なデータセットを使用してモデルをトレーニングしますトレーニングプロセスが完了したら、テストデータを使用してモデルの予測結果を取得できます。しかし、この従来の教師あり学習アプローチには、エラーのない大量のトレーニング データ セットが必要になるという重大な欠点があります。しかし、すべての分野でこのようなエラーのないデータセットを提供できるわけではありません。そこで、「少数サンプル学習」という概念が生まれました。

Sentence Transformer の微調整( SetFit )について詳しく説明する前に、自然言語処理 ( NLP )の重要な側面である「少量学習」について簡単に確認しておく必要があります

少数ショット学習

少数ショット学習とは、限られたトレーニング データ セットを使用してモデルをトレーニングすることを意味します。モデルは、サポート セットと呼ばれるこれらの小さなセットから知識を獲得できます。このタイプの学習は、トレーニング データ内の類似点と相違点を識別するために、少数ショット モデルを学習させることを目的としています。たとえば、モデルに特定の画像を猫か犬に分類するように指示するのではなく、さまざまな動物の共通点と相違点を理解するように指示します。ご覧のとおり、このアプローチは入力データの類似点と相違点を理解することに重点を置いています。そのため、メタ学習や学習のための学習も呼ばれます

少数ショット学習のサポート セットは、 k方向nショット学習とも呼ばれることに留意してください。ここで、「 k 」はサポートセット内のカテゴリの数を表します。たとえば、バイナリ分類では、 k2に等しくなります。一方、「 n 」はサポート セット内の各クラスで利用可能なサンプルの数を表します。たとえば、陽性クラスに10 個のデータ ポイントがあり陰性クラスに10 個のデータ ポイントがある場合、 n は10になります。要約すると、このサポート セットは双方向の10ショット学習として説明できます。

少数ショット学習の基本を理解したので、 SetFit の使用方法を簡単に学習し、それを e コマース データセットのテキスト分類に適用してみましょう。

SetFitアーキテクチャ

Hugging Faceと Intel Labs のチーム共同で開発したSetFit は少数ショットの写真分類用のオープンソース ツールです。 SetFit に関する包括的な情報は、プロジェクト リポジトリ リンク (https://github.com/huggingface/setfit?ref=hackernoon.com) で参照できます。

出力の場合、 SetFit は顧客レビュー ( CR ) 感情分析データセットからカテゴリごとに 8 つの注釈付き例のみを使用します。結果は、3,000 個の例の完全なトレーニング セットで微調整されたRoBERTa Largeの結果に匹敵しますサイズの点では、わずかに最適化されたRoBERTaモデルはSetFitモデルの 3 倍の大きさであることは強調する価値があります。次の図は SetFit アーキテクチャを示しています。

画像ソース: https://www.sbert.net/docs/training/overview.html?ref=hackernoon.com

SetFitによる高速学習

SetFitのトレーニング速度は非常に速く、効率的です。 GPT-3T-FEWなどの大型モデル比較しても、その性能は非常に競争力があります次の図を参照してください。

SetFitとT-Few 3Bモデルの比較

下の図に示すように、 SetFit は、Few-Shot 学習においてRoBERTaよりも優れています

SetFit と RoBERT の比較、画像ソース: https://huggingface.co/blog/setfit?ref=hackernoon.com

データセット

以下では、書籍、アパレルとアクセサリー、電子機器、家庭用品の 4 つの異なるカテゴリで構成される独自の e コマース データセットを使用します。このデータセットの主な目的は、電子商取引 Web サイトの製品説明を指定されたラベルに分類することです。

少数ショットのトレーニング アプローチを容易にするために、4 つのカテゴリのそれぞれから 8 つのサンプルを選択し、合計32 個トレーニング サンプルを作成します。残りのサンプルはテスト用に保管されます。簡単に言うと、ここで使用するサポート セットは4 8ショットの学習です。次の図は、カスタム e コマース データセットの例を示しています。

カスタム e コマース データセット サンプル

テキストデータをさまざまなベクトル埋め込みに変換するために、all-mpnet-base-v2というSentence Transformersの事前トレーニング済みモデルを採用しています。このモデルは、入力テキストに対して768次元のベクトル埋め込みを生成できます

以下のコマンドに示すように、 conda環境 (オープンソースのパッケージ管理システムおよび環境管理システム)に必要なパッケージをインストールして、 SetFitの実装を開始します。

 !pip3 install SetFit !pip3 install sklearn !pip3 install transformers !pip3 install sentence-transformers

パッケージをインストールしたら、次のコードを使用してデータセットを読み込むことができます。

 from datasets import load_dataset dataset = load_dataset('csv', data_files={ "train": 'E_Commerce_Dataset_Train.csv', "test": 'E_Commerce_Dataset_Test.csv' })

トレーニングサンプルとテストサンプルの数を確認するには、下の図を参照してください。

トレーニングおよびテストデータ

テキスト ラベルをエンコードされたラベルに変換するには、 sklearnパッケージLabelEncoderを使用します

 from sklearn.preprocessing import LabelEncoder le = LabelEncoder()

LabelEncoderを使用して、トレーニング データセットとテスト データセットをエンコードし、エンコードされたラベルをデータセットの「ラベル」列に追加します。次のコードを参照してください:

 Encoded_Product = le.fit_transform(dataset["train"]['Label']) dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features) Encoded_Product = le.fit_transform(dataset["test"]['Label']) dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)

次に、 SetFitモデルと sentence-transformers モデルを初期化します

 from setfit import SetFitModel, SetFitTrainer from sentence_transformers.losses import CosineSimilarityLoss model_id = "sentence-transformers/all-mpnet-base-v2" model = SetFitModel.from_pretrained(model_id) trainer = SetFitTrainer( model=model, train_dataset=dataset["train"], eval_dataset=dataset["test"], loss_class=CosineSimilarityLoss, metric="accuracy", batch_size=64, num_iteratinotallow=20, num_epochs=2, column_mapping={"Text": "text", "Label": "label"} )

両方のモデルを初期化したら、トレーニング手順を呼び出すことができます。

 trainer.train()

2 回のトレーニング エポックを完了したらeval_datasetでトレーニング済みモデルを評価します

 trainer.evaluate()

テストの結果、トレーニング済みモデルの最高精度は87.5%でした 87.5%という精度は高くありません、結局のところ、私たちのモデルはトレーニングに32 個のサンプルしか使用しませんでした。つまり、データセットのサイズが限られていることを考慮すると、テスト データセットで87.5%の精度を達成することは、実はかなり印象的です。

さらに、 SetFit はトレーニング済みのモデルをローカル ストレージに保存し、後でディスクからロードして将来の予測に使用することもできます。

 trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/") model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)

次のコードは、新しいデータに基づく予測結果を示しています。

 input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."] output = model(input)

予測出力は 1 であり、ラベルのLabelEncoded値は「衣類とアクセサリー」であることがわかります従来の AI モデルでは、安定したレベルの出力を実現するために、大量のトレーニング リソース (時間とデータを含む) が必要になります。それらと比較すると、私たちのモデルは正確かつ効率的です。

この時点で、基本的には「少量学習」の概念と、テキスト分類などのアプリケーションでSetFit を使用する方法を習得できたと思います。もちろん、より深い理解を得るためには、実際のシナリオを選択し、データセットを作成し、対応するコードを記述し、プロセスをゼロショット学習とワンショット学習に拡張することを強くお勧めします。

翻訳者紹介

51CTO コミュニティの編集者である Julian Chen 氏は、IT プロジェクトの実装で 10 年以上の経験があります。社内外のリソースとリスクの管理に長けており、ネットワークと情報セキュリティの知識と経験の普及に重点を置いています。

原題:テキスト分類のための SetFit による Few-Shot 学習の習得、著者: Shyam Ganesh S)


<<: 

>>:  マルチモーダル生成AIの深掘り

ブログ    

推薦する

AIと機械学習、5G、IoTは2021年に重要な技術となる

IEEEは、米国、英国、中国、インド、ブラジルの最高情報責任者(CIO)と最高技術責任者(CTO)を...

...

IDC: 生成型AIへの世界的な支出は2027年に1,430億ドルに達する

IDC は最近、世界中の企業による生成 AI サービス、ソフトウェア、インフラストラクチャへの支出が...

1000ステップ未満の微調整で、LLaMAコンテキストは32Kに拡張されました。これは、Tian Yuandongチームの最新の研究です。

誰もが独自の大規模モデルをアップグレードして反復し続けるにつれて、コンテキスト ウィンドウを処理する...

CNNとRNNの比較と組み合わせ

CNNとRNNはディープラーニングのほぼ半分を占めているので、この記事ではCNN+RNNとさまざまな...

人工知能に関するあまり知られていない3つの事実!古代中国にロボットは存在したのでしょうか?

時代の発展とテクノロジーの進歩に伴い、人工知能の分野も革新を繰り返しています。しかし、この神秘的な業...

アルゴリズムがバグをキャッチ:ディープラーニングとコンピュータービジョンが昆虫学を変える

[[390223]]導入コンピュータ アルゴリズムは、ソフトウェア プログラムのバグを検出するのに役...

フェデレーテッドラーニングも安全ではないのでしょうか? Nvidiaの研究は「プライバシーフリー」データを使用して元の画像を直接再構築します

フェデレーテッド ラーニングは、データがローカルの場所から出ないようにするプライバシー保護戦略により...

...

AI は無限であり、あなたの声によって動かされます。マイクロソフトは慈善団体や業界のパートナーと協力し、テクノロジーで愛を育むお手伝いをします。

12月2日、マイクロソフトと周迅のAI音声紅丹丹慈善プロジェクトの発起人である魯音源文化伝承社は、...

Redis のソースコードを読んで、キャッシュ除去アルゴリズム W-TinyLFU を学びましょう

[[433812]]この記事は董澤潤氏が執筆したWeChat公開アカウント「董澤潤の技術ノート」から...

...

プログラミング啓蒙ロボット、本物の人形か、それとも本当の物語か?

[[255856]]画像ソース @Visual China人工知能の普及により、中国の親たちの不安...

移動ロボットとは何ですか?また、どのように分類されますか?

移動ロボットは、作業を自動的に行う機械装置です。センサー、遠隔操作者、自動制御移動搬送機などから構成...