JX通信社は Go Conference 2021 Autumn に協賛します!【11/13(土)開催】

こんにちは! CDO(開発担当役員)の小笠原(@yamitzky)です。

Python のイメージが強い(?)JX 通信社ですが、実は Go に関してもかなり積極的に利用しています。Go のコミュニティ発展を目的として、Go Conference 2021 Autumn にブロンズプランでの協賛を行っています。

Go Conference 2021 Autumn は、今週土曜日 2021/11/13(土) 10:00 〜 20:00 で開催 されますので、ぜひ奮ってご参加ください(参加費無料)。

gocon.connpass.com

JX 通信社と Go

AI リスク情報SaaS「FASTALERT」や、ニュース速報アプリ「NewsDigest」のバックエンドは Python や Go で開発されています。

機能開発だけでなく、VaultTerraform など、Go で開発された OSS も活用しています。

また、僕自身も、今年の ISUCON は同僚と Go で参戦し、本選に進むことができました 💪

NewsDigest での使い方

NewsDigest の API は gqlgen という Go 製の GraphQL ライブラリを使って開発されています。

NewsDigest は速報の強いニュースアプリです。地震速報などの災害や、重大なニュース速報などは事前に予測することができないため、予期できないトラフィックのバーストを捌く必要があります。gqlgen は他言語を含めた GraphQL ライブラリの中でも高速なので、ベンチマーク検証の上、技術選定しました。

参考) Go の GraphQL API のパフォーマンス改善のために分散トレーシングを導入した話 - JX通信社エンジニアブログ

FASTALERT での使い方

FASTALERT のバックエンドでは、API に限らず、バッチ処理、イベント駆動型の処理など、幅広い箇所で Go が活用されています。

JX 通信社では、機能開発を行う際、マイクロサービス的にシステム分割をしながら開発しています。元々は Python によるシステム(マイクロサービス)が多かったのですが、新規システムの開発時に積極的に Go を選定したり、既存システムの置き換えなども行っています。

参考) Goの並行処理について - JX通信社エンジニアブログ

採用情報

JX 通信社では、FASTALERT や NewsDigest を進化させていくバックエンドエンジニアを募集しています! Python が書けない方でも大丈夫です。

もし少しでも「話を聞いてみたい」という方がいたら、小笠原まで Twitter で DM いただくか、Wantedlyコーポレートサイトなどでお声がけください。

open.talentio.com

それでは、明日は Go Conference 2021 Autumn を楽しみましょうー!

属人化しがちなR&Dをチーム開発するためのJX通信社での工夫

こんにちは!JX通信社でMLエンジニアのファンヨンテです。私は自分だけでなくMLチームの成果を最大化するために日々全力を尽くしています!

JX通信社のMLチームでは人的リソースを最大限活用するため "力を使うべき所にのみ注力しよう!" をスローガンに徹底的に効率化しています。 今回はちゅらでーた様と弊社の共同勉強会で私が発表した内容をより掘り下げてお伝えできればと思います。

本内容については

ちゅらデータさんとの共同勉強会にて発表しております!

speakerdeck.com

動画を見たい方はこちら

を御覧ください〜

R&Dタスクの属人化について

f:id:yoooongtae:20211015004452j:plain
図1 アプリ開発におけるチーム開発(左)とR&Dチームで発生しがちな属人化した、タスクの進め方

弊社はNewsDigestを始めとしたアプリを開発しており、アプリ開発の場ではチームの皆が一丸となり、アプリ開発という一つの目標に向かって協力し合いながら進めていきます。(図1左)

一方、R&Dのタスクでは、1タスクを一人が担当しているといったケースが多いのではないでしょうか?(図1右)

f:id:yoooongtae:20211015004718p:plain
図2 R&Dチームが属人化する理由と問題点

R&Dタスクが属人化してしまう理由は様々あると思いますが、その問題点はメンバー間の協力が非常に難しいことにあります。

それゆえに、だれかと同じ苦労を違う人が繰り返す。。。といったことが頻発してしまいます。これでは、 "力を使うべき所にのみ注力しよう!"という我々の信念からずれてしまいます。

個人の能力をフルに発揮しながらも、チームの協力を最大化し無駄な所に時間を割かないための我々JX通信社での工夫を紹介します。

R&Dをチームで行うためのJX通信社の工夫

コードのテンプレート化

f:id:yoooongtae:20211015004824j:plain

f:id:yoooongtae:20211015004829j:plain
図3 コードのテンプレート化について

R&Dのチーム内で協力が生まれにくい大きな原因は、他メンバーのコードを読み理解することが面倒なことだと思います。個々のメンバーが規則なく、自由にコーディングを行うとこの面倒さが発生します。この状況を回避するためにJX通信社では、機械学習モデルの学習時とデプロイ時にそれぞれ利用できるテンプレートを開発し利用しています。

テンプレートを作成した目的は以下のとおりです。

  1. コードの書き方に適度な矯正をすることで、可読性を増やすこと
  2. Poetryを用いて学習環境を管理することで、学習を引き継ぐ時などに学習環境を気にしなくても良くなること
  3. 後述するMLflow等の便利系のTipsを予め仕込むことで、初学者でも"いい感じに"学習できること

テンプレートには我々の信念である "力を使うべき所にのみ注力しよう!"と同じコンセプトで作られたPytorch Lightningを用いて開発されています。(私がPytorch Lightningヘビーユーザーであり2021年11月17日のMLOps勉強会にてPytorch Lightningについて発表します。もしよろしければ、来てください!)

このテンプレートコードはJXの社員・インターンにより随時アップデートされており、誰かが一度体験した苦労を他メンバーが体験しないようにしています。

MLflow + App Engineを用いた実験の一元管理

f:id:yoooongtae:20211015005604j:plain
図4 JX通信社におけるmlflowの使い方

実験した結果を実験者本人しか知らない状況を作ってしまうと、実験の引き継いだ後に同じ実験をしてしまったり、実験の工夫が結果に及ぼす影響を定量的に共有ができなかったりと、"力を使うべき所にのみ注力しよう!"ができません! そこでJX通信社では皆が行った実験結果を1元管理できる仕組みを作成しました。Pytorch Lightningを用いたテンプレートを用いて学習しているので、実験管理は数行のコードの追加で行うことができます。弊社では実験管理のOSSとして有名なMLflowを用いて行っています。

App EngineにMLflowサーバーを構築しており、学習時に結果とモデルがアップロードされるようテンプレートコードに組み込んであります。したがって、テンプレートコードを利用すると、(実験者は意識しなくても、google colabを含む)どのサーバーで学習しても学習状況とモデルがGCSにアップロードされ、JX通信社のすべてのメンバーがすべての実験内容を把握できるようになっています。

学習状況がSlackに通知されるシステム

f:id:yoooongtae:20211015004216j:plain
図5 Slackに学習状況が通知されるシステムについて

学習の引き継ぎや、"モデルA →モデルB"のようなパイプライン状のモデルの学習を複数メンバーで分担して行う場合、他メンバーの学習の進捗状況を知りたい時が頻繁にあります。直接、進捗状況をメンバーに聞くのもありですが、そのメンバーの集中を妨げることになります。

そこでJX通信社では学習の進捗情報と精度などの結果をSlackに自動で投稿してくれる昨日を開発しテンプレートに仕込んでます。

したがって、JXの皆がテンプレートを利用し学習してくれると、自動でSlackに学習状況が共有されることになります!

まとめ

今回の記事では、MLチームにおいて"力を使うべき所にのみ注力しよう!"を達成するにチームメンバーが互いに協力しあえる環境・システム構築についてお話しました。

"力を使うべき所にのみ注力しよう!"を達成するためにはシステムで解決できる部分と解決できない部分があると思います。今回は、システムで解決できると思った部分にフォーカスしてお話しましたが、システムで解決できない部分(価値のAIをそもそも作っているのか?メンバー間のコミュニケーションはうまく取れているのか?等)についても様々な工夫をとっており、これについてもいつか記事を書ければと思ってます!

JX通信社のMLチームは、ML以外のチームとも連携し合いながら、これからもどんどん成長していきたいと思ってます!

我々とともに挑戦する仲間を求めています

我々とともに成長しながら、より良い社会のためのMLを開発したい仲間を社員・インターン問わず積極的に募集しています!また、MLエンジニアはもちろん、あらゆる職種のエンジニアを求めています!

正社員、インターン、おためし入社などなど!ほんの少しでも興味を持たれた方はこちらを覗いてみてください!

PyCon JP 2021 でのミニゲームをGoogle Cloud RunとFirebase Hostingで作った話

JX通信社はPyCon JP 2021のゴールドスポンサーとしてスポンサーさせていただきました。 今回弊社平瀬さんがLocastの内容で登壇されました!内容はこちらからご覧ください。

speakerdeck.com

その際、会社のブースではPythonを使ったミニゲームを実施していました。

PyCon JP 2021について

f:id:jx_k_watanabe:20211019211553p:plain

PythonのカンファレンスであるPyCon JP 2021はオンラインとオフラインでのハイブリッド開催となりました。スポンサーブースはオンラインのみで、Discordのチャンネルに参加者の方に入っていただき、そこで会社の紹介や、参加者の方と交流するという流れになっていました。画像はブースで待機している図です。

f:id:jx_k_watanabe:20211019210537p:plain:w300

(JX通信社のブースです)

その中でJX通信社ではミニゲームを開催し、来場者の方にゲームを遊んでいただき、抽選で景品を差し上げることにしました。 オンラインブースに来てもらうためのコンテンツを用意すると、当日来てくれた方と話すきっかけになったり、宣伝もしやすかったので用意して良かったです。

f:id:jx_k_watanabe:20211019210616p:plain:w400

ゲームの内容

ゲームを作るに当たり、UI部分まで作るのは大変そうだったので、小規模に、かつPythonを使ってゲームを楽しめるようにできないか考えた結果、お客さんはエンジニアなのでターミナルからアクセスしてもらう形式にしました。

Responseはすべてテキストで返却して、回答者に答えをPostしていただく形式になっています。

f:id:jx_k_watanabe:20211019224510p:plain

(ちなみにResponseをapplication/jsonではなくtext/plainにすることでターミナル上で改行がきちんと反映されるというハックがあります)

また、回答していただけた方にアンケートを実施していたのですが、回答が全て好評でよかったです。

f:id:jx_k_watanabe:20211019223735p:plain:w500

要件

今回設定したのは以下の要件でした。

  • APIだけで完結させる
  • お客さんのHTTPClientで答えをPostしてもらう

システム構成

f:id:jx_k_watanabe:20211019210706p:plain

ソースコードはこちらで公開しています。

FastAPI(ゲーム部分)

API部分はFastAPIで作成しました。テキスト・問題・解答はyamlファイルに入れておき、ゲームの進行度に応じて適切な内容を返すようにしたので、ロジック部分はほぼありません。

f:id:jx_k_watanabe:20211019210729p:plain:w300

FastAPIは基本JSONレスポンスなので、文章を出す際に改行コードが出てしまい、若干見づらかったためテキスト形式でレスポンスを返すようにしました。

工夫した点1: Cloud Runからのレスポンスタイムが遅い

Cloud Runは、通常コールドスタートといって、リクエストが来ない間はインスタンスを立ち上げない状態のままにしておき料金を抑えてくれるという特徴があります。しかし、リクエストが来る最初のレスポンスは時間がかかってしまう問題があり、コールドスタートのままだとリクエストから5秒以上かかってしまうこともありました。ゲームは(一般的なAPIエンドポイントと比較すると)リクエスト数が少なくなるため、ユーザーがアクセスするときに時間がかかってしまうことは避けたかったので、Cloud Runのオプションで最低1つのインスタンスを立ち上げておくようにしました。

このようにデプロイ時にプロパティを設定することで、最低インスタンス数を指定することができます。

gcloud run deploy game --source . --min-instances 1

工夫した点2: ドメイン名を読みやすいようにしたい

当初はCloud Runが自動生成していたURLを利用していたのですが、実際にゲームを参加していただく方に分かりづらいドメイン名になってしまっていたため(https://game-mbqu6va7zq-an.a.run.app/ のようなドメインが生成されます)、ある程度読みやすいドメインがほしいなと思いました。

そこで、Firebase hostingを使って、ゲームに必要なドメインを作成することにしました。デフォルトだと、プロジェクト名そのままのドメインが生成されてしまうため、カスタムドメインを作成します。 Firebaseの画面から新規でウェブのアプリを作成します。その際に、Firebase Hostingの設定もできるため、ここで追加します。

f:id:jx_k_watanabe:20211019211204p:plain

その後、チュートリアルに従ってFirebaseの設定をします。加えて、Firebase HostingからCloud Runに送信するための設定を記述します。以下すべてのトラフィックをCloud Runに送信するための設定です。また、生成されたindex.htmlは削除しましょう。

"hosting": {
  // ...
  "rewrites": [ {
    "source": "**",
    "run": {
      "serviceId": "game", // ここにはCloud Runのプロジェクト名を指定します
      "region": "asia-northeast1"
    }
  } ]
}

おまけ

今回PyConJPの開催期間2日稼働させた結果、150円前後で抑えることができました。(Cloud Runの最低インスタンスを抑えることでもう少し安くできたかもしれないですが)

f:id:jx_k_watanabe:20211019211252p:plain:w500

まとめ

今回もスポンサーとして参加できてよかったです。

JX通信社はPythonistaを募集しております!

open.talentio.com

GCPをフル活用して東京五輪の2週間で約5000万ツイートをさばいた話

はじめまして。JX通信社でデータアナリストをしている @nrtaking です。

弊社では、7/23〜8/8に行われた東京オリンピック、8/25〜9/5に行われた東京パラリンピックにあわせて関連した日本語ツイートを全量収集し、Twitter Japanなど各社に提供していました。

内容に関する簡単な分析についてはプレスリリースでお伝えしているので、そちらもあわせてご覧ください。

prtimes.jp

実はこのツイート収集システムは、2週間ほどでほぼゼロから立ち上げたものでした。 今回は五輪関連のツイート収集を支えた技術について紹介します。

叶えたかった要件

  • 五輪に関するツイートを、NTTデータの提供するAPIからストリームで受け取り続ける
  • ツイート量などの統計情報やRTが多いツイート情報をダッシュボードの形で見ることができる
  • 上記を(ほぼ)リアルタイムで実現できる

実はこの取り組みにあたり、システム全体を一から構築する必要がありました。

提供が決まったとき、データ提供までのタイムリミットは約2週間。開催までは否定的な意見が目立っていたとはいえ、オリンピックが実際に始まると盛り上がることは容易に想像できました。 フルスクラッチで全部を作ることは早々に諦め、GCPのマネージドサービスをできるだけ活用する方向に切り替えました。

全体のアーキテクチャ

使ったサービス

  • Compute Engine
  • Cloud Pub/Sub
  • Cloud Dataflow
  • Cloud Storage
  • Cloud Composer
  • BigQuery
  • Google データポータル

f:id:nrtaking:20210901185815p:plain
ざっくりした構成

それぞれの役割

Compute Engine

  • ツイートを収集し続ける役割。
  • 弊社で提供しているプロダクトFASTALERT(ファストアラート)で同じような収集システムが存在していたので、そのコードを再利用しています。
  • といっても、コード自体は大したものではなく、全部で100〜200行程度。
  • 収集システムは Docker 化しておいて、インスタンス上で Docker コンテナを立ち上げています。Compute Engine には簡単に Docker コンテナをデプロイできる機能があるので、これを活用しました。

f:id:nrtaking:20210906173151p:plain
1秒間でも大量のツイートが取り込まれていきます

Cloud Pub/Sub

  • GCE で動いてるシステムからツイートを一つずつ受け止め、Cloud Dataflowに一つずつ流してあげる役割。
  • ここのリクエスト数のメトリクスを Datadog で監視してあげることで、 GCE がツイートを収集し続けられているか、死活監視していました。
    f:id:nrtaking:20210906173353p:plain
    この「リクエスト数」を死活監視に使ってました

Cloud Dataflow

  • Pub/Sub から受け取ったツイートを、5分おきに Cloud Storage に吐き出す役割。
  • Dataflow には Pub/Sub から MongoDB など、さまざまなテンプレートが用意されています。今回は Pub/Sub から Cloud Storage に吐き出すテンプレートを使用しました。
  • ここから BigQuery に直接吐き出す選択肢もありましたが、データのスキーマ設定で沼ったので諦めました。
  • Cloud Storage に吐き出す時は JSONL の形式にしてました。

Cloud Storage & BigQuery

  • 社内のデータ基盤では Cloud Storage から BigQuery にロードする構成を取っており、同様の構成としました。
  • Airflow で動かすDAGの定義のために50行くらいコードを書きましたが、これも社内のコードを再利用する形。
  • 社内のデータ基盤についてはこちらも見てみてください!

tech.jxpress.net

  • 5分おきにロードされるよう設定しましたが、重複排除の仕組みを作るのが少し面倒でした。
  • 新しく Cloud Composer のクラスタを立てると維持費が高くなってしまうため、先述したデータ基盤のクラスターに相乗りすることで維持費を節約しました。

Google データポータル

  • 簡単な可視化であれば、これで十分です。
  • 今回は特に、ダッシュボードの利用者も Google Workspace を使っていたため、重宝しました。
  • フロントの開発が必要なかったのが非常によい。

f:id:nrtaking:20210831185027p:plain
こんな感じのアウトプットです

気になる点① 処理性能と安定性

もちろん時間帯ごとにツイート数に差があるのですが、

  • 平均すると1秒あたり40ツイート程度を処理し続けました。
  • 結果的に、 五輪の2週間で約5000万ツイートを処理 することができました。

死活監視も仕込んでましたがアラートが来ることもなく、安定性もバッチリでした。

気になる点② 値段

実際に動かす前は結構気にしていた点で、動かし始めた当初は費用がかさんでいたのですが、いろんな工夫をした結果、1日あたりの費用は数十ドル程度までおさえることができました。

f:id:nrtaking:20210908212021p:plain
こういうリソースにお金がかかっています

  • Compute Engine のインスタンスタイプを可能な限り下げる
  • 特にお金がかかる Cloud Composer は既存プロジェクトのものを使う

といった工夫をしています。特に後者の工夫がなければ、費用は跳ね上がっていたと思います。

それでもなお課金額が多いサービスは

  • Cloud Monitoring (Compute Engine のロギング)
  • Cloud Dataflow

の2つです。ツイートごとにログを吐き出す形にしていたため、前者は流量に比例するものですが、後者はCPUやストレージが常に必要になる以上、その動作時間に対して常に費用がかかり続けます。

Cloud Dataflow の実体は Compute Engine 上にあり、設定時に割り当てるインスタンスタイプを選ぶことができるのですが、 g1-small など、共有コアを使う(安価な)インスタンスタイプを割り当てられないことも費用がかさむ要因です。

立ち上げてみて

2週間という短い期間にもかかわらず、社内のリソースとGCPのマネージドサービスを使うことで、大量のデータを処理できるシステムを作ることができました。

GCPは偉大です!

BERTの推論速度を最大10倍にしてデプロイした話とそのTips

背景

はじめまして、JX通信社でインターンをしている原田です。

近年深層学習ではモデルが肥大化する傾向にあります。2020年にopen aiが示したScaling Laws([2001.08361] Scaling Laws for Neural Language Models) の衝撃は記憶に新しく、MLP-Mixerが示したように、モデルを大きくすればAttention構造やCNNでさえも不必要という説もあります。([2105.01601] MLP-Mixer: An all-MLP Architecture for Vision

しかし大きな深層学習モデルを利用しようとすると、しばしば以下のような問題に悩まされます。

  • 推論速度が問題でプロダクトに実装不可能
  • GPU/TPUはコスト上厳しい
  • プロダクトの性質上バッチ処理が不可能(効率的にGPU/TPUが利用できない)

例えばJX通信社の強みは「速報性」にあるため、バッチ処理が困難であり、効率的なGPU/TPU利用が困難です。

しかし、機械学習モデルの精度はプロダクトのUXと直結するため、「なんとかCPU上で大きなモデルを高速に推論させたい」というモチベーションが発生します。

本記事は以上のような背景から大きなNLPモデルの代表格であるBERTを利用して各高速化手法を検証します。 さらに多くの高速化手法では推論速度と精度のトレードオフが存在し、そのトレードオフに注目して検証を行います。

実際に自分は下記で紹介する方法を組み合わせた結果、BERTの推論速度を最大約10倍まで向上させ、高速に動作させることに成功しました!

まとめ

今回検証した各高速化手法の各評価は以下になります。 (☆ > ◎ > ○ > △ の順で良い)

f:id:haraso1130:20210824183731p:plain
各手法のまとめ

ただし、タスクによって各手法の有効性が大きく変わるので実際に高速化を図る際には、その都度丁寧な検証が必要です。

各手法の説明と実装コード

以下から簡単に各高速化手法の概要と実装コードを解説します。

  • pruning, quantization, distillation, torchscriptはNLP以外でも利用可能な手法
  • max_lengthはNLPモデルであれば利用可能な手法です
  • 動的なmax_lengthはバッチサイズ==1で推論するときに利用可能な手法です。

quantization(量子化)

量子化とは、浮動小数点精度よりも低いビット幅で計算を行ったり、テンソルを格納したりする技術のことです。float32からint8へ変換することが一般的です。

ここではpytorch公式を参考にしました。

pytorch.org

Pytorchでは以下の三種類の量子化が用意されており、今回は最も簡単なdynamic quantizationを学習済みモデルに適応します。

  • dynamic quantization(動的量子化)...weightsのみ量子化し、活性化はfloatで読み書きを行う。学習済みモデルにそのまま適応し、計算を行う。
  • static quantization(動的量子化)...weightsと活性化を両方量子化する。学習後にキャリブレーションが必要である。
  • quantization aware training ...weightsと活性化を両方量子化する。トレーニング中から量子化をおこなう。

実装コードは以下になります。 以下のコードでは、BERTのnn.Linearの重みをfloat32→int8に変換しています。

def quantize_transform(model: nn.Module) -> nn.Module::
  model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
  )
  return model

distillation(蒸留)

蒸留は大きなモデルを教師モデルとし、教師モデルより小さなモデルを作成する手法です。 特にBERTの蒸留版モデルはDistilBERT(https://arxiv.org/pdf/1910.01108.pdf) として紹介されています。

BERT-baseはtransformerを12層利用していますが、DistilBERTはその半分の6層のtransformerを持った構造になっています。

また、損失関数は以下の三つから構成されており、解釈として「masked language task(単語穴埋め問題) をこなしながら、教師モデルと近い出力と重みを獲得する」と捉えることができます。

  • BERTのoutputとの近さ
  • masked language taskでの損失
  • BERTのパラメータとのコサイン類似度

今回の実験ではバンダイナムコが公開している日本語版distillbertモデルを利用しました。 https://huggingface.co/bandainamco-mirai/distilbert-base-japanese

huggingfaceのtransformersを利用することでとても簡単に使うことができます。

from transformers import AutoTokenizer, AutoModel
  
tokenizer = AutoTokenizer.from_pretrained("bandainamco-mirai/distilbert-base-japanese")

model = AutoModel.from_pretrained("bandainamco-mirai/distilbert-base-japanese")

pruning(剪定)

モデルの重みの一定割合で0にする手法で、モデルをスパースにすることができます。

ここでもpytorch公式のtutorialに沿って実装します。

pytorch.org

どの重みを剪定するかはさまざまな研究がありますが、ここでは上記tutorialで紹介されていたL1ノルム基準で削る手法を用いました。絶対値が小さい重みは重要度が低いと考えられるため0にしてしまうという発想はとても直感的です。

実装コードは以下になります。

import torch.nn.utils.prune as prune

PRUNE_RATE = 0.2

def prune_transform(model: nn.Module) -> nn.Module:
  for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=PRUNE_RATE)
        prune.remove(module, "weight")
  return model

上記のコードではモデル中のnn.Linearの重みのうち、絶対値が小さいものから20%を0に置き換えるという処理になります。

今回は複数のPRUNE_RATEで推論速度と精度の変化を実験しました。

torchscript(Jit)

TorchScriptは、PyTorchのコードからシリアライズ可能で最適化可能なモデルを作成する手法です。Python以外のC++等のランタイムで実行可能になります。

Pytorhはdefine by run方式を採用しており、動的に計算グラフを作成します。学習時には非常に有用なこの形式ですが、プロダクション上の推論時における恩恵はほとんどありません。

そこで、先にデータを流してコンパイルしてしましまおう(実行時コンパイラを使おう)というのが大まかな発想です。

より詳細な解説は以下の記事が非常にわかりやすいです。

towardsdatascience.com

簡単に解説すると、 - Torchscriptは中間表現コード - この中間表現は内部的に最適化されており、実行時に pytorchの実行時コンパイラであるPyTorch JIT compilationを利用する。 - PyTorch JIT compilationはpythonランタイムから独立しており、実行時の情報を用いて中間表現を最適化する

実装コードは以下になります。 torchscriptにはtraceとscriptの二つの作成方法がありますが、ここでは後からでも簡単に作成できるtraceを用います。

def torchscript_transform(model):
  model = torch.jit.trace(model, (SANPLE_INTPUT))
  return model

max_length

inputのmax_lengthを制限して入力データを軽くします。 transformersで前処理を行う場合、以下のような実装になります。

from transformers import BertTokenizer

MAX_LENGTH = 512

tokenizer = BertTokenizer.from_pretrained("hoge_pretrain")

data = tokenizer.encode_plus(
            TEXT,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

do_not_pad

この手法はbatch_size==1で推論する場合に利用可能な手法です。

通常batch推論をするために入力データのpaddingが必要ですが、batch_size==1の状況下ではパディングを行わずに推論することができます。

実装は以下になります。padding引数に'do_not_pad'を設定するだけです。

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("hoge_pretrain")

data = tokenizer.encode_plus(
            TEXT,
            add_special_tokens=True,
            max_length=512,
            padding="do_not_pad",
            truncation=True,
            return_tensors="pt",
        )

実験方法

今回の実験は精度と速度のトレードオフを測定することが主眼であるため、丁寧に精度の調査を行います。

環境

実行環境はgoogle colabで統一してあります。

Dataset

後述しますが、データセットによって有効な手法が変わるため特性が異なる複数のサンプルタスクを用意しました。

  • 一文が長いデータセット(livedoorトピック分類)
  • 一文が短いデータセット(twitter感情分類、ポジネガの2値分類で検証)

modelについて

精度評価方法

  • まず8:2でtrain/testに分割
  • trainのみを利用し、5fold stratified cross validation(全ての実験でfoldは固定)でモデルを学習
  • 5つのモデルでそれぞれtestに対して推論、averageしたものをtestの予測値とする。
  • cvとtestのacc & f1 macroで比較

速度評価方法

  • testセットからランダムに500個のデータをサンプリングし(全ての実験で共通)、batch_size==1で推論
  • 各データに対する推論時間の平均値と標準偏差で評価

結果

まず、各手法に対するtest scoreと速度のplotは以下のようになりました。 グラフの見方ですが、以下の通りです。

  • 一番左がベースライン
  • 赤と黄色のバーは精度を表しており上方向の方が良い
  • 青い点は推論時間で下方向の方が良い
  • エラーバーは標準偏差

twitter感情分類

f:id:haraso1130:20210824180416p:plain
twitterデータに対する精度と速度

livedoorトピック分類

f:id:haraso1130:20210824180443p:plain
livedoorデータに対する精度と速度

詳細な結果は以下になります。

twitter感情分類

手法 cv acc (f1-macro) test acc (f1-macro) 平均推論速度(s) 標準偏差(s)
BASELINE 0.8295 (0.8193) 0.8363 (0.8256) 0.2150 0.0050
quantization 0.8223 (0.8092) 0.8283 (0.8150) 0.1700 0.0048
distillation 0.8388 (0.8313) 0.8292 (0.8220) 0.1547 0.0076
max_length:64 0.8212 (0.8103) 0.8250 (0.8138) 0.1156 0.0036
do_not_pad 0.8295 (0.8193) 0.8363 (0.8256) 0.0987 0.0290
torchscript 0.8295 (0.8193) 0.8363 (0.8256) 0.1847 0.0080
pruning: 0.2 0.8327 (0.8226) 0.8283 (0.8173) 0.2124 0.0043
pruning: 0.4 0.8095 (0.7972) 0.8229 (0.8100) 0.1925 0.0041
pruning: 0.6 0.7097 (0.6787) 0.7597 (0.7198) 0.1925 0.0044
pruning: 0.8 0.5809 (0.5024) 0.6220 (0.3834) 0.1912 0.0046

livedoorトピック分類

手法 cv acc (f1-macro) test acc (f1-macro) 平均推論速度(s) 標準偏差(s)
BASELINE 0.9238 (0.9180) 0.9348 (0.9285) 0.7500 0.0079
quantization 0.9022 (0.8962) 0.9246 (0.9199) 0.6565 0.0068
distillation 0.8581 (0.8494) 0.8723 (0.8646) 0.5128 0.0079
max_length:256 0.8691 (0.8630) 0.8676 (0.8605) 0.4511 0.0062
do_not_pad 0.9238 (0.9180) 0.9348 (0.9285) 0.7012 0.0926
torchscript 0.9238 (0.9180) 0.9348 (0.9285) 0.7222 0.0083
pruning: 0.2 0.9204 (0.9144) 0.9355 (0.9302) 0.7633 0.0083
pruning: 0.4 0.8674 (0.8624) 0.8900 (0.8846) 0.7682 0.0084
pruning: 0.6 0.1973 (0.1176) 0.2057 (0.1025) 0.7496 0.1045
pruning: 0.8 0.1360 (0.0950) 0.1140 (0.0227) 0.7287 0.0075

考察

それぞれの手法についてより性能をわかりやすく表示するためBASEの精度と速度を1とし、各手法の性能を考察していきます。

twitter感情分類

f:id:haraso1130:20210824180325p:plain
Twitterデータに対する相対的な精度と速度

livedoorトピック分類

f:id:haraso1130:20210824180428p:plain
livedoorデータに対する相対的な精度と速度

quantization(量子化)

どちらのタスクにおいても殆ど精度を落とさずに推論時間を10~20%ほど削減することが可能です。 実装も容易であるため、高速化の際にはまず試してみたい手法です。

distillation(蒸留)

精度面ではタスクによって大きく結果が異なることがわかります。 twitterデータに対しては殆ど精度低下が見れらませんが、livedoorデータに対してはある程度の精度低下が認められます。

推論時間については約30%ほど削減できており、タスクによっては非常に有効な選択肢になり得ます。

max_length

どちらのタスクでも推論時間を40%~45%ほど削減できており、高速化において最も安定して寄与したといえます。

非常にインパクトが大きくセンシティブなパラメータであるため、ある程度速度が求められるシチュエーションの場合、まず初めにチューニングすべきパラメータです。

do_not_pad

この手法はデータセットによって大きく効果が異なる結果となりましたが、精度は不変であるため、バッチ処理が不可能な状況下では積極的に利用すべきです。

特に最大文字数が少なく、文字数の分散が大きいと考えられるツイッターデータセットではdo_not_padの影響は大きく、葯50%の推論時間をセーブすることができました。

torchscript

精度を落とさずに、少しではありますが推論速度を向上させることができます。

また、torchscriptはその他にも多くのメリットを有しており(Python以外のランタイムで実行可能、推論時にネットワークの定義が不要など)、プロダクションにデプロイする際はONNX等と並ぶ選択肢となります。

Pruning

今回の実験ではかなり微妙な結果でした。

twittterデータセットではpruning:0.4で10%ほどの推論時間削減を達成しましたが、その他の手法のトレードオフと比較するとコストパフォーマンスが低い印象です。その他の手法を全て適応した後、それでも高速化が必要ならば検討する、といったものになるでしょう。

また、livedoorデータセットにおいてはまさかの低速化に寄与する結果となってしまいました。

まとめ

本記事ではNLPモデルを高速にCPU上で動作させるため、各高速化手法について検証してきました。 タスクによって各手法のパフォーマンスが大きく異なるため、必要な精度と速度を見極めた後、最適な高速化手法の組み合わせを模索することが重要です。

その他にも有効な高速化手法があれば教えてくださると幸いです。

参考