リスク検知SaaSを支えるマルチモーダル・マルチタスクなExplainable AI

皆様こんにちは!JX通信社で機械学習エンジニアを担っているファンヨンテです。

弊社提供のビックデータ リスク情報サービスFASTALERTでは、Deep Learningを使ってSNSの投稿をリアルタイムに解析し、火事や事故などのリスク情報の検知を行っています。

SNSの投稿には、文字だけでなく、画像、動画などの情報も含まれているため、SNS解析にはよくマルチモーダルなAIモデルが用いられます。今回は「SNS の投稿からのリスク情報の判定」というタスクをテーマに、マルチモーダルなAIモデルの判定根拠の可視化や、精度を上げるための工夫などをご紹介します。

FASTALERT(ファストアラート)について

「FASTALERT」は、SNSをはじめとする各種ビッグデータから、AIがリスク情報を検知・配信するビックデータ リスク情報サービスです。報道に必要不可欠なツールとしてNHKと全ての民放キー局、全ての一般紙に採用されるなど、国内の大半の報道機関に浸透しています。また、最近では防災やBCP、障害監視やサプライチェーンのリスク管理など広範なニーズに対応する情報ツールとして、政府・自治体やインフラ企業をはじめとする幅広い業種の顧客に導入されています。URL:https://fastalert.jp

f:id:yoooongtae:20220408110109p:plain

図1FASTALARTについて

リスク検知タスクの課題

SNSからリスク情報を検知するというタスクの性質上、AIモデルには以下の4点の性質が求められます。

  1. 画像やテキスト等あらゆる情報を統合して判定したい!(マルチモーダル)
  2. リスク度だけでなく、リスク情報のカテゴリー(災害、事件、事故など)まで予測したい!
  3. リスク情報の検知漏れが起こってはならない!
  4. モデルに説明可能性をもたせたい!

マルチモーダルでマルチタスクができるAIのモデルアーキテクチャ

冒頭にも書いた通り、SNSの投稿にはテキストや画像、動画などの情報が含まれています。

また、FASTALERTでは、リスク情報を検知して配信するだけでなく、カテゴリー(災害、事件、事故などの区分)まで解析して、フィルタリングできるようにしています。リスク度の判定+リスクのカテゴリー予測を1つのモデルでマルチタスクで予想させると、ビジネス的な要件を満たすだけでなく、リスクの検知力自体を上げる効果もあります

まとめると、SNS からのリスク検知タスクに対しては、

  • 入力データ
    • テキスト
    • 画像
  • 出力データ
    • リスク検知
    • カテゴリー(災害、事件、事故などの区分)

が可能な”マルチモーダル””マルチタスク”ができるAI】が最適です。

AIの説明可能性

リスク情報を検知するというAIの性質上、AIの判断根拠を可視化することで、人間による監視や改善に役立てることができます。しかし、一般に公開されている多くのAIは”ブラックボックス”と言われ、判断根拠をうまく示すことができません。したがって、今回は判断根拠を示せるAIを開発し、実験を行いました。

AIの判断根拠を理解するためには、”挙動の理解”または”仕組みの理解”2つの方針があります。

”挙動の理解”は何かを変化させたときの出力の変化を見ることで、判断根拠を理解する方針であり、Grad_CAM, LIME, SHAP等が有名な方法です。しかし、マルチモーダルなモデルだと根拠の解釈が直感的ではないことや、計算時間が余計にかかり、リアルタイム性が失われてしまうことなどから、実装を見送りました。

”仕組みの理解”はAIモデルの仕組みを理解する事で判断根拠を見出す方針であり、Attentionが有名な方法です。Attentionを用いると、

  • 注目箇所が可視化されるので解釈があまり困らない
  • forwardの1回の計算で完結するので、計算時間が余計にかからない

のメリットがあるため、今回はAttentionを用いてAIが着目した部分を明瞭化し、判断根拠の可視化を行うことにしました。

実験

リスク検知力について

今回作成したモデルでは、全リスク情報ツイートの内99%を検知できるような再現率を達成することができました。弊社のFASTALERTでは、AIだけでなく最終的には有人による24時間体制の監視も加えて、リスク情報を漏れなく素早く検知し、より高品質な配信を実現しています。

モデルの判断根拠について

リスク情報を検知する判断根拠となる箇所を、わかりやすく可視化することで、さらに改善のための指標を得やすいようにしました。

以下に、火事、水害、横転事故のTweetを解析した例をあげました。画像は左がオリジナルで右がAIの注目箇所を表示しています。赤みが強いほど、AIがより注目したことを意味しています。

※ 以下の事例は実際のツイートではなく、ブログでの公開用に擬似的に作成したデータです(学習自体は実際のツイートで行っています)

f:id:yoooongtae:20220408110826p:plain
図2 火事についてのツイートを模したデータ。https://www.flickr.com/photos/sbeebe/7974757077から引用(licence: CC BY 2.0)。

f:id:yoooongtae:20220408111122p:plain
図3 水害についてのツイートを模したデータ。https://pxhere.com/en/photo/852904から引用(licence: CC 0.0)

f:id:yoooongtae:20220408111350p:plain
図4 横転事故についてのツイートを模したデータ。https://pxhere.com/en/photo/619517から引用(licence: CC BY 0.0)

これら結果を眺めていると、AIは各単語の関連性や、画像の情報と総合して判断箇所を注目していることがわかります。例えば、図2の火事の情報では、テキストは”消防車”と”火事”が、画像は炎の領域が注目されていることがわかります。この結果と図4の横転事故の情報を見比べてみましょう。すると”消防車”は、横転事故のときはあまり注目度が高くないのが見て取れます。図3の大雨の情報と、図4の横転事故を見比べても”道”という単語の注目度は変化してます。更に、興味深い点として、図4の横転事故のテキストの中で、最初に現れる”事故”は注目されていますが、後半の”事故”には注目されていないことがわかります。つまり、同じ文章内でも、重要度の区別をAIができていることを意味しています。

このように単純な単語のマッチングでは対応できない、複雑な言語の表現に対応した,マルチモーダルで説明可能なAIを作成することができました。

補足1:その他の前処理の工夫

リスク情報を漏らさないことを目的として、前処理をいくつか行ってます。その一つとして、絵文字の扱いを紹介します。

SNSにはテキスト情報の中に絵文字が含まれることがあります。自然言語処理を行う際にはノイズになるため絵文字は削除されることも多いのですが、リスク情報を検知する上では絵文字は重要な情報となりえます。 例えば以下のようなテキスト情報はリスク情報を含む可能性があります。

  • 東京⚡️やばい → 落雷・豪雨
  • 神田駅前🧑‍✈️👩‍✈️いっぱいいる。なにごと? → 事件
  • 東名高速で🚕 が🔥 → 交通事故

詳細な技術の説明は省きますがFASTALARTのAIはこれらの絵文字の情報も取りこぼすことなくリスク検知に役立てることで、情報の検知漏れがないようにしています。

補足2:チームでR&Dに取り組む工夫

JX通信社のMLチームにはインターン生がおり、今回の実験に関しても、インターンの小川遼人さんにサポートしていただきました。

過去にもエンジニアブログでご紹介しましたが、Pytorch Lightningをベースとした、テンプレートコードを用いて学習を行うことで、チームで効率よく実験ができるようにしています。

興味があれば以下のブログも読んてみてください tech.jxpress.net

tech.jxpress.net

我々とともに挑戦する仲間を求めています

我々とともに成長しながら、より良い社会のためのMLを開発したい仲間を社員・インターン問わず積極的に募集しています!また、MLエンジニアはもちろん、あらゆる職種のエンジニアを求めています!

正社員、インターン、そして副業・復業として体験的に働くことで、当社のカルチャーや働き方等を知っていただいたうえで正式入社を見極めていただくことが可能な「おためし入社」制度などもあります!ほんの少しでも興味を持たれた方はこちらを覗いてみてください!

slack-goとZapierで障害対応初動を自動化した話

f:id:TatchNicolas:20220316183046p:plain

サーバサイド開発やインフラ周りをいじっているたっち(TatchNicolas)です。今回はJX通信社における障害対応フローの改善について書きます。

はじめに

TL;DR;

  • slack-goとZapierを組み合わせて、障害対応時の提携作業を自動化するツールを作った
  • 「自動化しよう」という意見がでやすい場自体を仕組みとして整備したことが改善のきっかけになった

今回の話の背景

少し前に、ビープラウドさんとのイベントにて、JX通信社NewsDigestチームのCI/CDおよび障害対応についてお話しさせていただきました。*1

そのなかで、「障害対応時に専用のSlackチャンネルを都度作成していること」「Notionテンプレを使って情報の整理をしていること」*2を紹介しました。イベント後も何か障害やヒヤリハットが起こると少しずつフローが改善されていき、またNewsDigest以外のプロダクトチームでも同様のフローが採用されたりと自然と社内に浸透していきました。

JX通信社のポストモーテムでは、「発生した障害に対する技術的な振り返り、再発防止」の他に「対応のマニュアル、対応中のメンバーの振る舞い」に対する振り返りをする項目を明示的に設けています。その結果、「この手順、自動化したいな」という声が挙がったことから、ZapierとGolangを使って障害対応の初動を自動化する仕組みを作りました

f:id:TatchNicolas:20220316175706p:plain:w400
ポストモーテムで自動化の案がでたときのメモ

何をつくったか

任意のチャンネルで @incident-bot 障害発生 <プロダクト略称> <障害のひとこと説明> のような形式でSlackに書き込むと、

  • 障害対応Slackチャンネル作成(#incident-yyyy-mm-dd-<プロダクト略称>-<障害のひとこと説明>)
  • 情報をまとめるNotionページの作成
  • 関係者への連絡(然るべき常設Slackチャンネルへ新規作成した障害対応slackチャンネルとNotionページを投下)

を自動で瞬時に行ってくれます。

f:id:TatchNicolas:20220315230506p:plain:w400 f:id:TatchNicolas:20220315230514p:plain:w600

こういった作業は障害発生時のような急いでいる時ほど速く確実に行いたいので、自動化にはもってこいですね。

どう作ったか

はじめは、以前社内勉強会*3でも取り上げた https://github.com/slack-go/slack を使って適当な場所へデプロイしてサクッと完成させようと思ったのですが、Notion APIを使ってみたところ執筆時点でページ単位のAPIキーの発行に対応しておらず、あまり気軽にキーを発行するわけにもいきませんでした。

そこで社内で相談したところ「Zapier越しになら記録もいい感じに残るし使ってOK」との助言をもらったので、「じゃあいっそのことZapierで完結させるか!チャチャっと終わらせたいし!」と作り始めました。

Zapierの困りごと

Zapierは、SlackもNotionもネイティブに対応しており、Slack連携でチャンネル作成やメッセージ送信、Notion連携でページの作成などの処理を簡単に作ることができます。なのでポチポチしていくだけで簡単に今回の用件を満たすものが作れそうだと思ったのですが、Slackチャンネル作成の機能が日本語に対応していなかったので、 冒頭の例のようなチャンネルを作ろうとすると #incident-yyyy-mm-dd-nd-_ のように日本語がアンダースコアで置き換えられてしまいました。

これにより困ることが2つあって、一つは急いでいるときにチャンネル名から何が起こっているのかを特定できないことです。障害対応中はアラートを流しているチャンネルや修正のためのCI/CDの状態を流しているチャンネルなど、いろいろなチャンネルを行き来します。なので日本語でチャンネルを検索できた方が便利ですし、チャンネルが作成されたことを通知された人も何が起こっているのかパッとみてイメージしやすいので、日本語チャンネル名対応は諦めたくありませんでした。

もう一つは、あまり考えたくないですが同じ日に同じプロダクトで障害が複数発生*4した場合、日本語部分がアンダースコアになってしまうとチャンネル名が被ってしまいます。するとチャンネル作成がエラーになってしまいますし、ランダムな接尾辞などを付与するのもなんだかイケてないので困りました。

そこで「チャンネルを作る」というアクションは今後も変わることが少ないと想定されるので、チャンネル作成の部分は当初の予定通りGolangで開発したbotに任せ、それ以外をZapierに分担させることにしました。

JXならではの工夫

JX通信社は、私が所属しているプロダクトであるニュースアプリのNewsDigestのほか、To BサービスのFASTALERT、KAIZODEがあり、それぞれが共通して使う社内基盤であるXWireというシステムも存在します。

また弊社の文化として、各プロダクトのチームは技術選定やインフラ管理などを自分たちの意思と責任で実践しており、たとえば「本番デプロイのためのAWSの権限をどう管理・取得するか」のような本当に会社全体で守るべきルール以外はかなり大きな裁量がプロダクト開発チームに与えられています

一方で、他プロダクトの良いと思った開発手法やツールなどは積極的にお互いに真似していく文化もあります。私はこれを個人的に「ゆるやかで自然発生的な標準化」と呼んでいます。

なので、「障害が発生したときの動きを自動化したい」といっても、プロダクトごとに以下の設定が異なりますし、今後も違いが出てくるかもしれません。

  • 障害対応時に作成するNotionページにどんな情報を含めるか、どんなTODOを含めるか*5
  • 障害対応Slackチャンネルに誰を招待するか
  • 障害の発生を誰に通知するか
  • その他、「追加で○○がしたい」etc...

そこで、プロダクトごとに異なる部分をZapierのPath*6を使って表現することにしました。

結果的に、ツールの細かな挙動もZapierの管理画面からポチポチと変更できるので、「障害対応チャンネルに招待する人を変えたい」「通知先を変えたい」などの変更も気軽にできるようになりました

まとめ

ノーコードでポチポチ気軽に変更できる部分と、コードを書いて作る部分を組み合わせて、「ゆるやかで自然発生的な標準化」の文化を活かしつつも、共通する部分を自動化して運用負荷を下げようと試みました。ノーコードのツールを使うのは初めてで、変更差分の管理やコメントを入れにくいなど慣れない部分もありましたが、Zapierが思ったよりも高機能で驚きました。

今後使われていく中で機能を落としたり追加したり改修が増えてくると、「やっぱり普通にbotに寄せよう」「いやZapier使い倒そう」など作り方自体を変えていくかもしれませんが、現時点では「制約のなかで欲しいものをチャチャっと作る」を実現するのにちょうどいい塩梅にできたかなと思います。幸い、今回の仕組みを作ったあとでこれが必要になるようなトラブルはまだ発生していないので、実際に使ってもらいながら改善していきたいと思います。

今回のツールは障害発生時に使われるものなので出番が少ないほど嬉しい性質のものではありますが、こういった改善のアイディアが出てくるのはポストモーテムで考える範囲を広げて「人の動き方をもっと良くするにはどうすればよいか」について議論することを明示的に組み込んだ一つの成果ではないかと思います。

障害対応フローについては他社さんでも色々な工夫をされていると思うので、「こういう事例しってるよ」「ウチではこんなことしてるよ」など知見がありましたら、ぜひ教えていただければ嬉しいです。

*1:https://speakerdeck.com/tatchnicolas/cdtozhang-hai-dui-ying

*2:詳しくは↑の註釈のURLから資料を見ていただきたいですが、アラートチャンネルでそのままコミュニケーションを始めるとアラートと人間の会話が混ざってしまいますし、Slackだけでは情報の整理という意味で不足を感じたのでSlack+Notionを障害対応時のツールとして両方使っています

*3:https://tech.jxpress.net/entry/slack-app-101

*4:対応フローを整備したときに指針として「false positiveには寛大になろう」と明文化して障害発生を宣言することのハードルを下げているため、実際は何もなくともチャンネル名が重複する可能性を考慮しました

*5:障害対応ページをまとめるページもプロダクトごとに分かれているため、挿入先データベースの指定も異なります。

*6:https://zapier.com/help/create/customize/add-branching-logic-to-zaps-with-paths

AWS 上のシステムでリージョン切り替えの避難訓練を年末にやってみた

f:id:nsmr_jx:20220104095954j:plain あけましておめでとうございます。 サーバーサイドエンジニアの @kimihiro_n です。

今日はAWSに載っているシステムの避難訓練を実施したことについて書いてみようと思います。

弊社が提供している FASTALERT というサービスでは、全国の災害や事件などを検知して報道機関や自治体、インフラを支える企業などにリスク情報として提供しています。 リスク情報を提供するという性質上、情報検知の素早さや網羅性に加えて「システムの可用性」も重要なサービスの要素となっています。

FASTALERT の多くのシステムは AWS の東京リージョンで動いており、複数データセンターを活用した冗長化(マルチAZ)がされています。 しかし、例えば大規模地震のような広域かつ被害の大きい災害の場合、東京リージョン全体にわたって問題が発生する可能性があります。 首都直下型の大きな地震は今後30年以内に70%の確率で発生すると予想されており*1東京リージョンのみに頼らない構成が必要であると考えています。

また災害に限らず、AWSの特定のサービスで障害が発生した場合でも、別のリージョンへ切り替えることで素早く復旧できるケースもあるため、マルチリージョンでシステムを構成しておくことは可用性を高めてくれます。

マルチリージョン化の進め方

AWS のマルチリージョン化については、主要なコンポーネントから少しずつ進めていきました。 FASTALERT のシステムはマイクロサービス的に機能ごとに複数のコンポーネントに分かれており、止まってしまうと困る部分を優先的にリージョン単位で冗長化していきました。

実際の作業としては、リージョン間でネットワークの相互疎通できるよう土台を整え、それからデータベースやアプリケーションの冗長化に手をつけていきました。

要となるデータベースは Amazon Auroraを利用しているので、グローバルデータベース機能が利用できました。 グローバルデータベースを活用すると、データベースのリードレプリカが別リージョンへ簡単に作成できます。 また東京リージョンで障害が発生した際には、フェイルオーバー操作によって書き込み用の Writer へと昇格することが可能です。

コンポーネントの冗長化にあたっては、費用の面もあるのでホットスタンバイするものとコールドスタンバイにしておくものなど適宜使い分けるようにしています。 データベースなどリージョン切り替えに時間を要するものはホットスタンバイとして常時動かしておき、ECSで動くサービスなど素早く起動して利用できるものは停止した状態で冗長化しています。

システム自体の冗長化に加えて、欠かせないのが移行手順のドキュメント化です。 東京リージョンがまるごと利用困難になるレベルの災害を想定すると、東京にいるエンジニアも対応できない状況になっている可能性が高いです。 こうした非常時に、リモートで動けるエンジニアへバトンタッチ出来るようしっかりとしたドキュメントを用意しておく必要があります。

避難訓練の実施

個々のコンポーネントの冗長構成と動作確認は適宜行っていたのですが、全体を通して「東京リージョンの機能がほぼ使えなくなった」シナリオでの検証はやったことがありませんでした。 シナリオを想定してみることで、どの順序で何をすべきかや、どれくらい時間がかかるかなどを洗い出すことが出来ます。

ちょうど年末年始の休みで開発のキリがいいタイミングがあったため、避難訓練を実施してみることにしました。

ちなみにAWS のマネジメントコンソールは使える仮定でやっていました。実際になってみないと何が使えて何が使えないのかは分からないですが、何もかも使えないことに対処しようとすると無限に工数と費用がかかってしまうため、出来る範囲で対処できるケースを増やしていくのが大事だと思います。

避難訓練は、メンバーの1人に画面共有してもらい、マニュアル通りにリージョンの切り替え作業を行ってもらう形で実施しました。 他のメンバーは手順が正しいことを確認しつつ、サービスの状態を監視したり、どのタイミングで何を操作したかの記録を取っていきました。 マニュアルで分かりづらい点や不具合、エラーがでた箇所なども適宜メモしています。

避難訓練を実施してみて

元の状態に戻すところまで含めて1時間ぐらいで終わるかなと思っていたのですが、実際は2時間かかってしまいました

時間がかかった要因としては、マニュアル通りにうまくいかなかったことがあったことが挙げられます。 ドキュメントを作成してからシステム自体に変更が入りそのままでは動かなくなってしまっているケースや、ターゲットのリージョンで操作すべき項目を東京リージョン側で操作してしまったケースなどがありました。 練習なので落ち着いて調査出来ましたが、実際に障害が発生しているときだと大変です。 事前に不備を訓練で洗い出すことが出来たのはよかったです。

またこうすればもっと省力化できそう、みたいなアイデアも出てきたので避難訓練を実施した価値は十分あったかなと思います。 大規模な災害は起こってほしくないですが、万が一の際でも安定して提供できるサービスを作っていきたいですね。

爆速開発を目指して NewsDigest を Flutter にリプレイスします

f:id:jazzsasori:20211128104040p:plain
爆速開発を目指して NewsDigest を Flutter にリプレイスします

JX通信社 Engineering Manager の @jazzsasori です。
最近アークナイツというソシャゲに課金してしまいましたが妻には内緒にしています。

弊社は NewsDigest という無料ニュースアプリを運営しています。
NewsDigest は記者が業務で愛用するほど、その圧倒的スピードに強みがある速報アプリです。また、一般的なニュース分野での速報に加えて、報道はされにくいが個人にとって価値の高い情報も to B 向けのリスク情報SaaS である FASTALERT と連携して即時に伝える、社会派ニュースアプリです。
現在 (2021/11) 500万ダウンロードを突破しており、今後もさらにユーザーを伸ばそうとしています。

なぜリプレイスを行うのか

サービスとしては 2015年 にストアで公開されたので今日現在 (2021年12月) アプリとしてはもうすぐ8年目となります (すごい)。
iOS でいくと Swift 1 → 2 → 3 → 4 という移行も乗り越えてきました。

技術的負債という課題

7年以上運用していると、どれだけ気をつけていてもいわゆる「技術的負債」が溜まってしまいました。
時が経つにつれて技術的負債の課題は深刻なものになっていきました。
例えばある機能を改修する際、

  • 技術的負債となっている部分が解決すれば1日で終わるタスクに2日かかる
  • 簡単に終わると思っていた改修箇所が実は同じ変更内容を2箇所に適用しなければならなかった

など、よくある技術的負債による工数の肥大化が発生していました。

※ NewsDigest はさまざまな立場で多くのエンジニアの方に改修していただいた歴史があります。
技術的負債に関して否定的なことも記述はしておりますが、関わっていただいたメンバーの方に最大限の尊敬と感謝の念を込めて書いています。

爆速開発できるコードベースをつくりたい

技術的負債の側面に加え、我々が描く未来をより速く実現したいと考えています。
例えば弊社の提供する最新感染状況マップ・感染者数情報 は多くの方に価値を感じていただき、テレビ番組などでは多く特集いただきました。
このようなユーザーにとって価値の高い情報をお届けしつつ、多くの方にご利用いただくことにより、たくさんのフィードバックをいただき、さらにサービスを改善していきたいと考えています。

我々は技術的負債の解消・今後の開発速度の向上という二つの観点から Flutter によるフルリプレイスを行う という選択をしました。

Flutter を選択した背景

NewsDigest という大規模でかつユーザーの多いニュースアプリをリプレイスする技術選定としてはさまざまな可能性を模索しました。

  • 継続してSwift / Kotlin で開発
  • React Native
  • Kotlin Multi Platform
  • ...

特に React Native に関しては弊社の Frontend では React を多く使うこともあり、相性がよいのではないか、という議論もありました。

会社としても大きな投資となるため、経営層・マネージャー層で議論を進めていたところ、弊社のエンジニアから「NewsDigest をもし Flutter で書き直すならこんな感じかなと思って書いてみました」という連絡がありました。リポジトリを覗いてみるとけっこう書き進めてくれていて...というストーリーがあり、 チームメンバーからのボトムアップで Flutter という技術選択をしました。

Flutter でどうリプレイスを進めているのか

今回のテーマとしては「技術的負債の解消」という大きなテーマがあります。
今後の開発速度を爆速にするため、なるべく技術的負債が溜まりにくくするため、アーキテクチャにはこだわっています。

具体的には、まず大きな考え方として Clean Architecture の考え方を参考にしつつ Onion Architecture をベースとしたアーキテクチャを選択しています。例えば社内のドキュメントには SOLID 原則 にのっとってコードを書くように案内するようなものもあります。

また、FAT な Widget を作らないようにすることも重要です。
(私も iOS で 2000行を超える FatViewController を書いたこともあります...)
こちらは Atomic Design の考え方にのっとって Widget を設計しています。
何をもって molecules か、何をもって organisms か、というのは難しい問題ですが、考え方をすり合わせながら・ドキュメント化しながらコーディングルールを明確化していっているのが現状です。

ステート管理については Riverpod を採用しています。

今後どう進めていくのか

NewsDigest の規模のアプリをフルリプレイスするのには工数がかかります。
数ヶ月の工数をかけて出来るだけ正しい形、今後の開発速度が爆速になるよう + 負債が溜まりにくくなるような形でリプレイスを完遂しようとしています。
エンジニアだけではなく、セールスメンバーにも理解いただきつつチーム一丸となってリリース目標を立てて進めています。

正直なところ忙しいプロジェクトではあります。
一方で多くのユーザーの方にご利用いただいているニュースアプリを根本から改善する、という挑戦的な面白いプロジェクトでもあり、メンバーにはモチベーション高く挑んでいただいています。  

読んでいただいている Flutter エンジニアの方、ぜひ一度弊社のお話を聞いていただけないでしょうか?
フリーランス、正社員問わず募集しています。
こちらからご応募可能なのでぜひお気軽にご連絡ください 🙇

open.talentio.com


NewsDigest を使ってみたいという方は下記よりダウンロードください 🙇‍♂️
iPhone 版 :
app.adjust.com

Android 版: app.adjust.com

ヘビーユーザーが解説するPyTorch Lightning

こんにちは!私はファンヨンテと申します!JX通信社で機械学習エンジニアを行っております! 私はPyTorch Lightningを初めて使ったときの便利さに感動した以来、PyTorch Lightningのヘビーユーザーです! この解説記事ベビーユーザーの私が皆様にPyTorch Lightningを知っていただき、利用のきっかけになってほしいと思って公開しています!

今回の解説記事のサンプルコードはこちらにあります。ぜひ、実際のコードを手にとって体験しPyTorch Lightningの素晴らしさに触れてみてください!

この記事内容は13回のMLOps勉強会で発表しました! speakerdeck.com

読者の対象

  • PyTorch Lightningを使ってみたいが、最初の初め方がわからない人
  • PyTorch をある程度知っている方

PyTorch Vs PyTorch Lightning

PyTorch について

PyTorchはDeep Learningを実行する時に用いられる代表的なフレームワークですが、自由度が高く、個々人で自由な記述が可能です。これはメリットでもありますが、可読性が非常に低くなるリスクがあります。また、学習時のモデルや結果の保存など、サイエンスの側面とは異なるコード書くタイミングが多く、コーディングに多くの時間を取られます。したがって、PyTorchで機械学習を行うなら、データサイエンティストにはサイエンスの能力とコーディングの2つの能力が求められます。

f:id:yoooongtae:20211112121029p:plain
図1 PyTorchを用いた時に、MLエンジニアが行う業務

PyTorch Lightningについて

一方、データサイエンティストが力を存分に発揮し、これまで想像されなかった革新的なことをしたいなら、コーディングの部分よりサイエンスの領域に力を注ぐべきだと思います。

PyTorch Lightningを用いると、コーディングにかける時間を最小限にし、データサイエンティストがサイエンスの領域に全力を捧げられるようになります。

f:id:yoooongtae:20211112121449p:plain
図2 PyTorch Lightningを用いるとサイエンスの部分に注力できる

JX通信社でPyTorch Lightningを採用した理由

JX通信社のMLチームでは"力を使うべき場所に注力しよう"ということを理念を大事にしており、その工夫の一つとして学習・デプロイのテンプレートコードを作成し、運用しています。そのテンプレートはPyTorch Lightningをベースに記述されています。詳しくはこちらのブログをご参照ください。

PyTorch Lightningとは

PyTorch Lightningは"データサイエンティストはサイエンスに力を捧ぐべきで、コーディングは最小限の労力で"をコンセプトに作成された、PyTorch のラッパーです。PyTorch Lightningは後述するように、コードの書き方に指定があり、サイエンスに関わる部分(モデルのアーキテクチャ、学習方法、前処理の方法など)をメインに書きます。このコードの書き方の適度な矯正はチーム間での可読性を上げる効果もあります!また、GPU、TPUでの学習や、モデルと結果の保存など複雑なコーディンに関わる部分は数行書き足すだけで、実行することができるようになっています。

f:id:yoooongtae:20211112121356p:plain
図3PyTorch LightningのHPより引用

PyTorch Lightningの柔軟性について

この様に聞くと、柔軟性が欠けており、できない実験があるのでは?と感じる方もいると思います。しかし、柔軟性を最大化することがPyTorch Lightningの哲学として紹介されており、ほとんどすべての実験を再現することが可能です。

f:id:yoooongtae:20211112121738p:plain
図4Lightning Design Philosophy Githubより引用

またPyTorch Lightningをずっと使ってきたの私個人的な意見としては、こちらのブログにも記述されている通り、業務においてPyTorch Lightningの記述で困ったことはほとんどなく、メリットのみを享受しています。(しかし、半教師あり学習をPyTorch Lightningで記述しようとした際、Optimizerの記述に沼ってしまい、余計な時間と労力を費やしたことが一度あります。)

PyTorch Lightning の書き方

サイエンスに関わる部分

PyTorch Lightningに書き方のフレームワークがあり、そこを埋めていくように書くと記述しました。このページでは、PyTorch Lightningの書き方を、PyTorch との差分という意識で記述させていただきます。このムービーは非常にわかりやすいムービーですので、こちらのムービーをご覧になっていただいてから、下の章を読んでみてください。 Google Colabで動作するサンプルコードはこちらにあります。ブログの内容と全てが一対一対応してませんが、実際ご自身の手で動かしてみてください!

PyTorch LightningはDataに関わる部分(DatasetやDataloader等)と学習に関わる部分(モデルの構造、学習ループの書き方)に大きく区分されます。

Dataに関する部分(LightningDataModule)

PyTorch では学習に用いるデータはtorch.utils.dataDataLoaderDatasetを用いて記述されます。PyTorch Ligthningでは、LightningDataModuleを継承したクラスを作成し、その中のtrain_dataloader()と名のついた関数にtrainに用いるDataloaderをreturnする必要があります。validation, testも同様です(図5)。ここで注意すべきなのは関数名は変更してはなりません。この様に関数名が完全に決まっているので、他のチームメンバーがコードを読むときにどのデータがどこに利用されているのかわかりやすくなります。 LightningDataModuleには、このページで紹介した関数以外にも多くの関数が定義されています。LightningDataModuleについての詳細は公式のドキュメントをご覧ください.

f:id:yoooongtae:20211115112938p:plain
図5LightningDataModuleについて

学習に関する部分(LightningModule)

PyTorch では、torch.nnのmoduleを継承したクラスでモデルの構造を決定し、その後、学習ループをtrain.pyなどにベタがきすることが多いと思います。PyTorch Lightningでは、モデルの構造と学習ループでの計算挙動をLightningModuleで定義します。

図6はLightningModuleに必要な最低限のコードを示しています。まず、initでモデルのパーツを定義し、forwordでモデルの構造(計算の流れ)を定義します(ここまではPyTorchと同じ)。

次にPyTorch Lightningでは、学習ループの各ステップでどの様に計算してほしいかを記述します。図6では、training_step()という関数が定義されていますが、この関数は学習ループのミニバッチが与えられたときに、どの様に計算するのか記述してます。ここで、return のdict中の"loss"を鍵とする値を基にLightningの方で逆伝播や最適化が行われます。このときに用いられるOptimizerやschedulerはconfigure_optimizers()に定義する必要があります。

図6では、training_step(学習のミニバッチに対する処理)だけが書かれてますが、training_epoch_endvalidation_step, validation_epoch_end, on_fit_startなど、学習ループの各ステップで行ってほしいことの記述が可能になっています。

また、PyTorchでは学習ループの記述の際、データやモデルをCPUやGPUに移動する必要があり、.to(device)が至る所に記述する必要があると思います。PyTorch Lightningを用いているとto(device)の記述しなくても、Lightningが自動で学習に用いるデバイスにデータを移動してくれます。LightningModuleについての他の関数を含む詳細はこちらを御覧ください。

f:id:yoooongtae:20211115113357p:plain
図6 LightningModuleについて

学習の初め方

上で定義した、LightningDataModuleLightningModuleを用いて学習するには、Trainerを定義し、trainer.fitで学習されます(図7)。

f:id:yoooongtae:20211115113500p:plain
図7Trainerで学習が開始される。右の図は実際のコード下の図は学習が開始されたときのコンソールの表示

サイエンスに関わるコードの書き方まとめ

このページではサイエンスに関わるコードの記述法について記述してきました。PyTorch Lightningを利用していると、上記で解説したサイエンス部分の実装に費やす時間が多くなります! 一方、ここまで読まれた方は以下の疑問を思ったと思います。

  • 結果の表示や保存はどうするの?
  • モデルの保存方法は?
  • 学習のデバイスをどの様に指定するのか? こちらは、Lightningを用いていると、数行の追加で終了します。これがLighthningの強みだと思います。詳細は次章以降のページにて解説します!

サイエンス以外の部分のコードの書き方(結果の表示と保存について)

前章ではPytorch Lightningのサイエンスに関わる部分のコードの記述方について解説しました。

ここでは、結果の表示法と保存法について記述します。

前回の解説で、Lightningでは学習ループの段階ごとの計算挙動を記述すると記述しました。このときに見たい/保存したい結果があるとき、self.log()に見たいメトリクスを入れると、結果が記録されます (図8)。それぞれの引数の意味は以下のとおりです。

self.log(
        "test_loss", #metiricsの名前
        loss,#metiricsの値
        prog_bar=True,#プログレスバーに表示するか?
        logger=True,#結果を保存するのか?
        on_epoch=True,#1epoch中の結果を累積した値を利用するのか?
        on_step=True,#1stepの結果を利用するのか?
        )

f:id:yoooongtae:20211115123647p:plain
図8LightningModuleでのログの取り方について

TorchMetricsの利用

TorchMetricsのライブラリを利用すると結果を非常に簡単に扱うことができます。TorchMetricsを用いることで、AccuracyやRecallなど一般的な指標が簡潔に書けるだけでなく、MetiricsCollectionを用いると複数の指標を一度に定義し(図9左)、計算することができます(図9右)。

f:id:yoooongtae:20211115123725p:plain
図9TorchMetricsに利用法

結果の表示

self.log()のprog_bar = Trueに設定したMetricsはプログレスバーに表示されます (図10)。

f:id:yoooongtae:20211115123816p:plain
図10 self.logを入れたMetricsがプログレスバーの下に結果が表示される

結果の保存方法

self.log()のlogger = TrueにしたMetricsの保存はどの様になるのでしょうか?

LightningのTrainerを定義するときにloggerを定義すると、その形式にあったログが自動で保存されます。このloggerにはTensorBoardだけでなく、MLFlow, Cometなど有名な実験管理ライブラリに対応していますので、自身の好きなライブラリを用いることができます。

f:id:yoooongtae:20211115123901p:plain
図11loggerについて。loggerにTensorBoardを定義すると(左上)、自動で結果がTensorBoardの形式で保存される(右)。

JX通信社ではGoogle App EngineにデプロイされたMLFlowに結果が保存される様にしています。詳細はこちらを御覧ください。

サイエンス以外の部分のコードの書き方(callback : モデルの保存とearly stopping)

前章は結果の表示と保存方法について解説しました。 このページではモデルの保存方法とEarly stoppingの方法について説明します。 そのためにはLightningのCallbackと呼ばれる機能について理解する必要があるため、先にcallbackについての解説を行った後にモデルの保存方法とEarly stoppingの方法について説明します。

callbackについて

PyTorch Lightning にはCallbackというmoduleがあります。Callbackを用いると、LightningModuleで定義した学習ループの計算挙動に新たな動作を差し込むことができます。 例として、図12の左の図の様に定義したcallbacksをtrainerに入れると、ミニバッチ学習が始めに"train_batch is starting"、終わりに"train_batch is completed"と表示されるようになります。 ですので、学習で普遍的に用いられる便利系のコードをcallbacksで定義しておくことで、LightningModule内の学習ループの定義は簡潔になります。

f:id:yoooongtae:20211115124839p:plain
図12 callbackとは

JX通信社でのcallbackの利用

JX通信社ではWebHookを通じて学習の進捗状況をSlackに知らせるcallbacksを作成し運用しています。詳しくはこちらを御覧ください!

f:id:yoooongtae:20211115124921p:plain
図13JX通信社でのcallbacksの利用。(https://tech.jxpress.net/entry/2021/10/27/160154より引用)

モデルの保存とearly stopping

Lightningから提供されるBuilt-in callbackを利用することもできます。

ここにはModelCheckpoint, EarlyStoppingなどの便利なcallbackがあります。

ModelCheckpointは指定した指標が最も高い/低いモデルを保存することができるcallbackです。このcallbackの素晴らしいポイントは監視対象を選択し、その値が最大/ 最小のときのモデルを保存するといったベストモデルを保存することができます(図3の場合、valid_lossが最小になるモデルが保存される)。 EarlyStoppingは文字通りearly stoppingのためのcallbackであり、監視対象の値が改良しない場合、学習がストップしてくれます。この2つのcallbacksをTrainerの引数に入れ込むだけで、モデルの保存とearly stoppingを簡単に実装することが可能になります。

f:id:yoooongtae:20211115125127p:plain
図14Lightningが提供するカスタムBuilt in callback(ModelCheckpoint, EarlyStopping)

その他、Built in callbackについては公式のドキュメントをご覧ください。

サイエンス以外の部分のコードの書き方(GPU /TPU /IPUでの学習)

GPU, TPUで学習速度を上げる

Deep learningのモデルはcpuで学習させるには非常に時間がかかります。一方GPU, TPUを用いると、学習期間が数倍~数十倍程度加速されるため、GPU, TPUで学習することが一般的だと考えられます (図15)。 一方PyTorchを用いてGPU, TPUで学習させることを考えると、至る所に.to(divice)を書く必要があり、可読性が下がるだけでなく、不必要なエラーが発生する原因になります。

f:id:yoooongtae:20211115184604p:plain
図15GPU, TPUの利用は学習速度を加速させる。このブログより引用

LightningでのGPU, TPUで利用法

LightningModuleではto(device)のコードを一切書くことなく学習ループを実装します。では、Lightningではどの様にGPU/TPU/IPUで学習することができるのでしょうか? 答えは非常に簡単で、Trainerのgpus/tpu_cores/ipusで利用するGPU/TPU/IPUを指定すると指定したデバイスで学習が開始されます (図16)。

ここで、例えば、gpusの引数にlist型を入れると、学習に用いるGPUのデバイスを詳細に指定することができます。一つの要素しか無いlistを入れるとGPU1台での学習になりますが、複数の要素のlistを入れる、または2以上のintを入れると複数GPUでの並列学習が可能になります。より詳細は公式のドキュメントを御覧ください。

f:id:yoooongtae:20211115182808p:plain
図16学習するデバイスを変更するためにはTrainerの引数を指定するだけで良い

ColabでのTPU利用について

google colabでTPUの利用したい場合は、以下のコマンドを打ってxlaをインストールする必要があります。

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

(ここで、torch_xla-1.9の1.9は利用しているTorch のVersionを指定してください)

その後Tranerのtpu_coresの引数を指定することで、TPUでの学習が開始されます。

PyTorch Lightningより上位のラッパー (PyTorch Lightning Flash)

これまで紹介してきたとおりPyTorch LightningはPyTorchのラッパーでコーディングにかける時間を最小限にし、サイエンスにかける時間を最大化してくれました。

ここではPyTorch Lightningのより高位ラッパーであるPyTorch Lightning Flashについてご紹介させていただきます!

PyTorch Lightning Flashは高レベルのAIフレームワークであり、このライブラリを用いることで、図17のような典型的なタスクであれば、十数行のコードで解決することができます。また、PyTorch Lightningのラッパーであるため、Pytorch Lightningで利用できた機能(callbackや, loggingなど)をそのまま利用することができます。

サンプルコードの前半にPyTorch Lightningで行った画像分類タスクを後半のPyTorch Lightning Flashで実装しています!興味がある方はこちらも触ってみてください。

Lightning Flashができるタスクについてはこちらを御覧ください。

f:id:yoooongtae:20211115184143p:plain
図17Lightning Flashから引用

PyTorch Lightning Flashの個人的な利用方法

PyTorch Lightning Flashは数行のコードで機械学習タスクが完結する反面、柔軟性に乏しく、タスク目的に合わせた工夫が困難です。

したがって、個人的にはPyTorch Lightning Flashでベースラインを構築し、精度が十分ならそのまま利用します。一方、タスクに合わせた工夫が求められたときにはPyTorch Lightningのテンプレートコードを基に実験を行います。PyTorchの利用は最終手段だと思っていて、Lightning単体で実装できない最先端の工夫をしたい時に利用すると思います!

f:id:yoooongtae:20211115190731p:plain
図18 Pytorch のラッパーの個人的な利用方法

我々とともに挑戦する仲間を求めています

我々とともに成長しながら、より良い社会のためのMLを開発したい仲間を社員・インターン問わず積極的に募集しています!また、MLエンジニアはもちろん、あらゆる職種のエンジニアを求めています!

正社員、インターン、おためし入社などなど!ほんの少しでも興味を持たれた方はこちらを覗いてみてください!

日本でPyTorch Lightningを盛り上げていきたい!

個人的に、日本でPyTorch Lightningを盛り上げていきたい!と思っています。

PyTorch Lightningをこれまで使っていた方、これから使いたいと思っている方、質問がある方!積極的に私のtwitter, またはFacebookにご連絡ください!