BERTの推論速度を最大10倍にしてデプロイした話とそのTips

背景

はじめまして、JX通信社でインターンをしている原田です。

近年深層学習ではモデルが肥大化する傾向にあります。2020年にopen aiが示したScaling Laws([2001.08361] Scaling Laws for Neural Language Models) の衝撃は記憶に新しく、MLP-Mixerが示したように、モデルを大きくすればAttention構造やCNNでさえも不必要という説もあります。([2105.01601] MLP-Mixer: An all-MLP Architecture for Vision

しかし大きな深層学習モデルを利用しようとすると、しばしば以下のような問題に悩まされます。

  • 推論速度が問題でプロダクトに実装不可能
  • GPU/TPUはコスト上厳しい
  • プロダクトの性質上バッチ処理が不可能(効率的にGPU/TPUが利用できない)

例えばJX通信社の強みは「速報性」にあるため、バッチ処理が困難であり、効率的なGPU/TPU利用が困難です。

しかし、機械学習モデルの精度はプロダクトのUXと直結するため、「なんとかCPU上で大きなモデルを高速に推論させたい」というモチベーションが発生します。

本記事は以上のような背景から大きなNLPモデルの代表格であるBERTを利用して各高速化手法を検証します。 さらに多くの高速化手法では推論速度と精度のトレードオフが存在し、そのトレードオフに注目して検証を行います。

実際に自分は下記で紹介する方法を組み合わせた結果、BERTの推論速度を最大約10倍まで向上させ、高速に動作させることに成功しました!

まとめ

今回検証した各高速化手法の各評価は以下になります。 (☆ > ◎ > ○ > △ の順で良い)

f:id:haraso1130:20210824183731p:plain
各手法のまとめ

ただし、タスクによって各手法の有効性が大きく変わるので実際に高速化を図る際には、その都度丁寧な検証が必要です。

各手法の説明と実装コード

以下から簡単に各高速化手法の概要と実装コードを解説します。

  • pruning, quantization, distillation, torchscriptはNLP以外でも利用可能な手法
  • max_lengthはNLPモデルであれば利用可能な手法です
  • 動的なmax_lengthはバッチサイズ==1で推論するときに利用可能な手法です。

quantization(量子化)

量子化とは、浮動小数点精度よりも低いビット幅で計算を行ったり、テンソルを格納したりする技術のことです。float32からint8へ変換することが一般的です。

ここではpytorch公式を参考にしました。

pytorch.org

Pytorchでは以下の三種類の量子化が用意されており、今回は最も簡単なdynamic quantizationを学習済みモデルに適応します。

  • dynamic quantization(動的量子化)...weightsのみ量子化し、活性化はfloatで読み書きを行う。学習済みモデルにそのまま適応し、計算を行う。
  • static quantization(動的量子化)...weightsと活性化を両方量子化する。学習後にキャリブレーションが必要である。
  • quantization aware training ...weightsと活性化を両方量子化する。トレーニング中から量子化をおこなう。

実装コードは以下になります。 以下のコードでは、BERTのnn.Linearの重みをfloat32→int8に変換しています。

def quantize_transform(model: nn.Module) -> nn.Module::
  model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
  )
  return model

distillation(蒸留)

蒸留は大きなモデルを教師モデルとし、教師モデルより小さなモデルを作成する手法です。 特にBERTの蒸留版モデルはDistilBERT(https://arxiv.org/pdf/1910.01108.pdf) として紹介されています。

BERT-baseはtransformerを12層利用していますが、DistilBERTはその半分の6層のtransformerを持った構造になっています。

また、損失関数は以下の三つから構成されており、解釈として「masked language task(単語穴埋め問題) をこなしながら、教師モデルと近い出力と重みを獲得する」と捉えることができます。

  • BERTのoutputとの近さ
  • masked language taskでの損失
  • BERTのパラメータとのコサイン類似度

今回の実験ではバンダイナムコが公開している日本語版distillbertモデルを利用しました。 https://huggingface.co/bandainamco-mirai/distilbert-base-japanese

huggingfaceのtransformersを利用することでとても簡単に使うことができます。

from transformers import AutoTokenizer, AutoModel
  
tokenizer = AutoTokenizer.from_pretrained("bandainamco-mirai/distilbert-base-japanese")

model = AutoModel.from_pretrained("bandainamco-mirai/distilbert-base-japanese")

pruning(剪定)

モデルの重みの一定割合で0にする手法で、モデルをスパースにすることができます。

ここでもpytorch公式のtutorialに沿って実装します。

pytorch.org

どの重みを剪定するかはさまざまな研究がありますが、ここでは上記tutorialで紹介されていたL1ノルム基準で削る手法を用いました。絶対値が小さい重みは重要度が低いと考えられるため0にしてしまうという発想はとても直感的です。

実装コードは以下になります。

import torch.nn.utils.prune as prune

PRUNE_RATE = 0.2

def prune_transform(model: nn.Module) -> nn.Module:
  for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=PRUNE_RATE)
        prune.remove(module, "weight")
  return model

上記のコードではモデル中のnn.Linearの重みのうち、絶対値が小さいものから20%を0に置き換えるという処理になります。

今回は複数のPRUNE_RATEで推論速度と精度の変化を実験しました。

torchscript(Jit)

TorchScriptは、PyTorchのコードからシリアライズ可能で最適化可能なモデルを作成する手法です。Python以外のC++等のランタイムで実行可能になります。

Pytorhはdefine by run方式を採用しており、動的に計算グラフを作成します。学習時には非常に有用なこの形式ですが、プロダクション上の推論時における恩恵はほとんどありません。

そこで、先にデータを流してコンパイルしてしましまおう(実行時コンパイラを使おう)というのが大まかな発想です。

より詳細な解説は以下の記事が非常にわかりやすいです。

towardsdatascience.com

簡単に解説すると、 - Torchscriptは中間表現コード - この中間表現は内部的に最適化されており、実行時に pytorchの実行時コンパイラであるPyTorch JIT compilationを利用する。 - PyTorch JIT compilationはpythonランタイムから独立しており、実行時の情報を用いて中間表現を最適化する

実装コードは以下になります。 torchscriptにはtraceとscriptの二つの作成方法がありますが、ここでは後からでも簡単に作成できるtraceを用います。

def torchscript_transform(model):
  model = torch.jit.trace(model, (SANPLE_INTPUT))
  return model

max_length

inputのmax_lengthを制限して入力データを軽くします。 transformersで前処理を行う場合、以下のような実装になります。

from transformers import BertTokenizer

MAX_LENGTH = 512

tokenizer = BertTokenizer.from_pretrained("hoge_pretrain")

data = tokenizer.encode_plus(
            TEXT,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

do_not_pad

この手法はbatch_size==1で推論する場合に利用可能な手法です。

通常batch推論をするために入力データのpaddingが必要ですが、batch_size==1の状況下ではパディングを行わずに推論することができます。

実装は以下になります。padding引数に'do_not_pad'を設定するだけです。

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("hoge_pretrain")

data = tokenizer.encode_plus(
            TEXT,
            add_special_tokens=True,
            max_length=512,
            padding="do_not_pad",
            truncation=True,
            return_tensors="pt",
        )

実験方法

今回の実験は精度と速度のトレードオフを測定することが主眼であるため、丁寧に精度の調査を行います。

環境

実行環境はgoogle colabで統一してあります。

Dataset

後述しますが、データセットによって有効な手法が変わるため特性が異なる複数のサンプルタスクを用意しました。

  • 一文が長いデータセット(livedoorトピック分類)
  • 一文が短いデータセット(twitter感情分類、ポジネガの2値分類で検証)

modelについて

精度評価方法

  • まず8:2でtrain/testに分割
  • trainのみを利用し、5fold stratified cross validation(全ての実験でfoldは固定)でモデルを学習
  • 5つのモデルでそれぞれtestに対して推論、averageしたものをtestの予測値とする。
  • cvとtestのacc & f1 macroで比較

速度評価方法

  • testセットからランダムに500個のデータをサンプリングし(全ての実験で共通)、batch_size==1で推論
  • 各データに対する推論時間の平均値と標準偏差で評価

結果

まず、各手法に対するtest scoreと速度のplotは以下のようになりました。 グラフの見方ですが、以下の通りです。

  • 一番左がベースライン
  • 赤と黄色のバーは精度を表しており上方向の方が良い
  • 青い点は推論時間で下方向の方が良い
  • エラーバーは標準偏差

twitter感情分類

f:id:haraso1130:20210824180416p:plain
twitterデータに対する精度と速度

livedoorトピック分類

f:id:haraso1130:20210824180443p:plain
livedoorデータに対する精度と速度

詳細な結果は以下になります。

twitter感情分類

手法 cv acc (f1-macro) test acc (f1-macro) 平均推論速度(s) 標準偏差(s)
BASELINE 0.8295 (0.8193) 0.8363 (0.8256) 0.2150 0.0050
quantization 0.8223 (0.8092) 0.8283 (0.8150) 0.1700 0.0048
distillation 0.8388 (0.8313) 0.8292 (0.8220) 0.1547 0.0076
max_length:64 0.8212 (0.8103) 0.8250 (0.8138) 0.1156 0.0036
do_not_pad 0.8295 (0.8193) 0.8363 (0.8256) 0.0987 0.0290
torchscript 0.8295 (0.8193) 0.8363 (0.8256) 0.1847 0.0080
pruning: 0.2 0.8327 (0.8226) 0.8283 (0.8173) 0.2124 0.0043
pruning: 0.4 0.8095 (0.7972) 0.8229 (0.8100) 0.1925 0.0041
pruning: 0.6 0.7097 (0.6787) 0.7597 (0.7198) 0.1925 0.0044
pruning: 0.8 0.5809 (0.5024) 0.6220 (0.3834) 0.1912 0.0046

livedoorトピック分類

手法 cv acc (f1-macro) test acc (f1-macro) 平均推論速度(s) 標準偏差(s)
BASELINE 0.9238 (0.9180) 0.9348 (0.9285) 0.7500 0.0079
quantization 0.9022 (0.8962) 0.9246 (0.9199) 0.6565 0.0068
distillation 0.8581 (0.8494) 0.8723 (0.8646) 0.5128 0.0079
max_length:256 0.8691 (0.8630) 0.8676 (0.8605) 0.4511 0.0062
do_not_pad 0.9238 (0.9180) 0.9348 (0.9285) 0.7012 0.0926
torchscript 0.9238 (0.9180) 0.9348 (0.9285) 0.7222 0.0083
pruning: 0.2 0.9204 (0.9144) 0.9355 (0.9302) 0.7633 0.0083
pruning: 0.4 0.8674 (0.8624) 0.8900 (0.8846) 0.7682 0.0084
pruning: 0.6 0.1973 (0.1176) 0.2057 (0.1025) 0.7496 0.1045
pruning: 0.8 0.1360 (0.0950) 0.1140 (0.0227) 0.7287 0.0075

考察

それぞれの手法についてより性能をわかりやすく表示するためBASEの精度と速度を1とし、各手法の性能を考察していきます。

twitter感情分類

f:id:haraso1130:20210824180325p:plain
Twitterデータに対する相対的な精度と速度

livedoorトピック分類

f:id:haraso1130:20210824180428p:plain
livedoorデータに対する相対的な精度と速度

quantization(量子化)

どちらのタスクにおいても殆ど精度を落とさずに推論時間を10~20%ほど削減することが可能です。 実装も容易であるため、高速化の際にはまず試してみたい手法です。

distillation(蒸留)

精度面ではタスクによって大きく結果が異なることがわかります。 twitterデータに対しては殆ど精度低下が見れらませんが、livedoorデータに対してはある程度の精度低下が認められます。

推論時間については約30%ほど削減できており、タスクによっては非常に有効な選択肢になり得ます。

max_length

どちらのタスクでも推論時間を40%~45%ほど削減できており、高速化において最も安定して寄与したといえます。

非常にインパクトが大きくセンシティブなパラメータであるため、ある程度速度が求められるシチュエーションの場合、まず初めにチューニングすべきパラメータです。

do_not_pad

この手法はデータセットによって大きく効果が異なる結果となりましたが、精度は不変であるため、バッチ処理が不可能な状況下では積極的に利用すべきです。

特に最大文字数が少なく、文字数の分散が大きいと考えられるツイッターデータセットではdo_not_padの影響は大きく、葯50%の推論時間をセーブすることができました。

torchscript

精度を落とさずに、少しではありますが推論速度を向上させることができます。

また、torchscriptはその他にも多くのメリットを有しており(Python以外のランタイムで実行可能、推論時にネットワークの定義が不要など)、プロダクションにデプロイする際はONNX等と並ぶ選択肢となります。

Pruning

今回の実験ではかなり微妙な結果でした。

twittterデータセットではpruning:0.4で10%ほどの推論時間削減を達成しましたが、その他の手法のトレードオフと比較するとコストパフォーマンスが低い印象です。その他の手法を全て適応した後、それでも高速化が必要ならば検討する、といったものになるでしょう。

また、livedoorデータセットにおいてはまさかの低速化に寄与する結果となってしまいました。

まとめ

本記事ではNLPモデルを高速にCPU上で動作させるため、各高速化手法について検証してきました。 タスクによって各手法のパフォーマンスが大きく異なるため、必要な精度と速度を見極めた後、最適な高速化手法の組み合わせを模索することが重要です。

その他にも有効な高速化手法があれば教えてくださると幸いです。

参考