カテゴリー【Python】
【Python】Spacyを使用して文章から出発地と目的地を抜き出す
POSTED BY
2024-07-20
2024-07-20
対話ロボットを作成するとき、ユーザーの入力文中に含まれる特定の単語を抜き出したい。
大久保駅から吉祥寺駅までの行き方を教えてください。
という文章を解析して、
出発地: 大久保駅, 目的地: 吉祥寺駅
と抽出する。これには自然言語処理ライブラリがいくつかあるがSpacyというものをトレーニングして使ってみる。
環境の整備
mkdir spacy_test cd spacy_test python -m venv source venv/bin/activate pip install spacy ginza ja-ginza python -m spacy download ja_core_news_sm
日本語のSpacyモデルをダウンロードしている。
モデルを作成してトレーニング
Python | spacy_train.py | GitHub Source |
import spacy from spacy.training import Example from spacy.tokens import DocBin # サンプルデータセットの作成 train_data = [ ("東京駅から大阪駅までの行き方を教えてください。", {"entities": [(0, 3, "DEPARTURE"), (5, 8, "DESTINATION")]}), ("渋谷駅から新宿駅への行き方を知りたいです。", {"entities": [(0, 3, "DEPARTURE"), (5, 8, "DESTINATION")]}), ("名古屋駅から京都駅までの最短ルートは?", {"entities": [(0, 4, "DEPARTURE"), (6, 8, "DESTINATION")]}), # さらに多くのデータを追加 ] # 日本語モデルのロード nlp = spacy.blank("ja") # NERパイプラインの追加 ner = nlp.add_pipe("ner") # 新しいラベルの追加 for _, annotations in train_data: for ent in annotations.get("entities"): ner.add_label(ent[2]) # トレーニングデータの変換 def convert_data(data): db = DocBin() for text, annot in data: doc = nlp.make_doc(text) ents = [] for start, end, label in annot["entities"]: span = doc.char_span(start, end, label=label) if span is not None: ents.append(span) doc.ents = ents db.add(doc) return db # トレーニングデータの準備 train_db = convert_data(train_data) train_db.to_disk("./train.spacy") # トレーニングの実行 nlp.begin_training() for itn in range(30): # エポック数を増やしてみる losses = {} examples = [] for text, annotations in train_data: doc = nlp.make_doc(text) example = Example.from_dict(doc, annotations) examples.append(example) nlp.update(examples, drop=0.35, losses=losses) print(f"Epoch {itn + 1} - Losses: {losses}") # トレーニング済みモデルの保存 nlp.to_disk("./ner_model") print("Model trained and saved.")
とりあえず3行ほどの行き方を尋ねるサンプル文を作成し、文中に含まれる駅名の位置を出発地と目的地に分けてラベリングしたものをトレーニングに入れ込む。
python spacy_train.py Epoch 1 - Losses: {'ner': 35.0999950170517} Epoch 2 - Losses: {'ner': 34.133325695991516} Epoch 3 - Losses: {'ner': 32.990808963775635} Epoch 4 - Losses: {'ner': 31.658286452293396} Epoch 5 - Losses: {'ner': 29.982840061187744} Epoch 6 - Losses: {'ner': 27.79977011680603} Epoch 7 - Losses: {'ner': 25.32037329673767} Epoch 8 - Losses: {'ner': 21.030838012695312} Epoch 9 - Losses: {'ner': 19.09130835533142} Epoch 10 - Losses: {'ner': 16.101310461759567} Epoch 11 - Losses: {'ner': 11.924833744764328} Epoch 12 - Losses: {'ner': 10.00987720862031} Epoch 13 - Losses: {'ner': 9.994316773489118} Epoch 14 - Losses: {'ner': 9.19360821461305} Epoch 15 - Losses: {'ner': 9.794028195785359} Epoch 16 - Losses: {'ner': 8.14222964423243} Epoch 17 - Losses: {'ner': 8.681412859354168} Epoch 18 - Losses: {'ner': 8.006356921629049} Epoch 19 - Losses: {'ner': 7.872147118701832} Epoch 20 - Losses: {'ner': 6.413324706401909} Epoch 21 - Losses: {'ner': 6.0774027310544625} Epoch 22 - Losses: {'ner': 5.448554979404435} Epoch 23 - Losses: {'ner': 33.613754868507385} Epoch 24 - Losses: {'ner': 11.632518198341131} Epoch 25 - Losses: {'ner': 14.990542967570946} Epoch 26 - Losses: {'ner': 13.025165064726025} Epoch 27 - Losses: {'ner': 13.370349167846143} Epoch 28 - Losses: {'ner': 13.174954702844843} Epoch 29 - Losses: {'ner': 12.035214286064729} Epoch 30 - Losses: {'ner': 10.9738067840226} Model trained and saved.
作成したモデルをロードしてテスト文から抽出させる
Python | spacy_exec.py | GitHub Source |
import spacy # トレーニング済みモデルの読み込み nlp = spacy.load("./ner_model") # ユーザーの入力に対する出発地と目的地の抽出 def extract_entities(text): doc = nlp(text) departure = None destination = None for ent in doc.ents: if ent.label_ == "DEPARTURE": departure = ent.text elif ent.label_ == "DESTINATION": destination = ent.text return departure, destination # テスト用の入力 user_input = "大久保駅から吉祥寺駅までの行き方を教えてください。" departure, destination = extract_entities(user_input) print(f"出発地: {departure}, 目的地: {destination}")
python spacy_exec.py 出発地: 大久保駅, 目的地: 吉祥寺駅
トレーニングデータには大久保も吉祥寺も登場していないにかかわらず、出発地と目的地と判定して抜き出しに成功。
このあたりはトレーニングのクオリティと時の運にもよるので、ミスを無くすにはより大量のデータが必要になってくるだろう。
Android
iPhone/iPad
Flutter
MacOS
Windows
Debian
Ubuntu
CentOS
FreeBSD
RaspberryPI
HTML/CSS
C/C++
PHP
Java
JavaScript
Node.js
Swift
Python
MatLab
Amazon/AWS
CORESERVER
Google
仮想通貨
LINE
OpenAI/ChatGPT
IBM Watson
Microsoft Azure
Xcode
VMware
MySQL
PostgreSQL
Redis
Groonga
Git/GitHub
Apache
nginx
Postfix
SendGrid
Hackintosh
Hardware
Fate/Grand Order
ウマ娘
将棋
ドラレコ
※本記事は当サイト管理人の個人的な備忘録です。本記事の参照又は付随ソースコード利用後にいかなる損害が発生しても当サイト及び管理人は一切責任を負いません。
※本記事内容の無断転載を禁じます。
※本記事内容の無断転載を禁じます。
【WEBMASTER/管理人】
自営業プログラマーです。お仕事ください!ご連絡は以下アドレスまでお願いします★
【キーワード検索】
【最近の記事】【全部の記事】
CORESERVER v1プランからさくらインターネットスタンダートプランへ引っ越しメモさくらインターネットでPython MecabをCGIから使う
さくらインターネットのPHPでAnalytics-G4 APIを使う
インクルードパスの調べ方
【Git】特定ファイルを除外する.gitignore
【Ubuntu/Debian】NVIDIA関係のドライバを自動アップデートさせない
【Python】Spacyを使用して文章から出発地と目的地を抜き出す
HomeBrewでApache2を入れて自動起動つきで動かしPHPモジュールと連携する
macOSに標準付属のApacheを自動起動つきで動かす
HomeBrewでPostgreSQLを入れて自動起動つきで動かす
【人気の記事】【全部の記事】
【Windows10】リモートデスクトップ間のコピー&ペーストができなくなった場合の対処法Windows版Google Driveが使用中と言われアンインストールできない場合
【C/C++】小数点以下の切り捨て・切り上げ・四捨五入
進研ゼミチャレンジタッチをAndroid端末化する
Windows11+WSL2でUbuntuを使う【2】ブリッジ接続+固定IPの設定
Googleスプレッドシートで図形をコピーして使いまわすには
【Linux】iconv/libiconvをソースコードからインストール
【Apache】サーバーに同時接続可能なクライアント数を調整する
Pythonで処理にかかった時間を計測するには
Windows11のコマンドプロンプトでテキストをコピーする
【カテゴリーリンク】
Android
iPhone/iPad
Flutter
MacOS
Windows
Debian
Ubuntu
CentOS
FreeBSD
RaspberryPI
HTML/CSS
C/C++
PHP
Java
JavaScript
Node.js
Swift
Python
MatLab
Amazon/AWS
CORESERVER
Google
仮想通貨
LINE
OpenAI/ChatGPT
IBM Watson
Microsoft Azure
Xcode
VMware
MySQL
PostgreSQL
Redis
Groonga
Git/GitHub
Apache
nginx
Postfix
SendGrid
Hackintosh
Hardware
Fate/Grand Order
ウマ娘
将棋
ドラレコ