この記事は JX通信社Advent Calendar の 12 日目です。
FASTALERT チーム機械学習エンジニアの mapler です。FASTALERT の機械学習とサーバーサイドの開発をしています。
FASTALERT(ファストアラート)は、SNSから事件・事故・災害等の緊急情報を検知し、配信する緊急情報配信サービスです。その処理の中でも画像認識は欠かせない存在です。
今回お話しするのは CNN (Convolutional Neural Network,または畳み込みニューラルネットワーク) というニューラルネットのモデルです。CNN は行列の空間情報を捉えるため、特に画像認識分野では非常に有効な手法です。
CNN の解釈性
ニューラルネットワークはとても有効な一方、その根拠が解釈しにくいとよく言われています。FASTALERT でも、ニュースの価値を判別するアルゴリズムを改善するためにモデルの判定結果を解釈することは重要です。
CNN もしくは深層学習の解釈性について、icoxfog417さん の ディープラーニングの判断根拠を理解する手法 でたくさんの研究や手法が紹介されています。
今回は Grad-CAM という判定根拠の可視化方法について実験してみようと思います。
Grad-CAM の仕組み
上の図に、入力画像の受容野(Receptive field)がいくつかの Convolution + ReLU + Pooling 層を通って、一次元ベクトルに Flatten される直前まで、位置が変わらないことがわかります。(左上の部分が複数の Convlution 層を通った後の出力でも左上にあります)
A number of previous works have asserted that deeper representations in a CNN capture higher-level visual constructs [5, 35]. Furthermore, convolutional features naturally retain spatial information which is lost in fully-connected layers, so we can expect the last convolutional layers to have the best compromise between high-level semantics and detailed spatial information.
from: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
(拙訳)たくさんの研究によって、より深い層ではより高度、豊富な特徴が捉えられると示されています。しかし、Flatten によって、空間情報は分類の fully-connected 層で失われます。最後の CNN 層は分類特徴と空間情報を両方持つ、可視化に最も利用できる層になります。
Grad-CAM はこの最後の CNN 層の勾配を利用して、どのニューラルのどの部分が出力のどの分類に一番貢献したかを計算します。
上の図は CAM という Grad-CAM が登場する前の CNN 根拠可視化手法です。
CAM は Grad-CAM と異なり、勾配を利用するのではなく、CNN 層の後の Fully-Connected 層と一つの GAP(Global Average Pooling)に入れ替えています。この GAP は(豊富な特徴情報を持っている)最後の CNN の出力の特徴図(Feature Map)を Pooling して、分類のクラスとマッピングします(Class Activation Mapping)。
上の図で犬(Australian terrier)を示す (赤い四角)と
(緑の四角)の特徴図の重みは
(青い四角)より強いのがわかります。(逆に人間を判定する場合、
の重みは強くなるでしょう。)
は CNN の出力特徴図(Feature Map)
がクラス c に判定される確率の
に対する偏微分、もしくは勾配(gradient)となります。
ここの勾配はニューラルネットワークの逆伝播で計算され、特徴図 の中の i、j 位置のピクセルの変化に対し、クラス c に判定される確率の影響を表しています。そして、この勾配はちょうど CAM の重みと同じになることを論文の中でも証明しています。(Grad-CAM は汎用化した CAM だと論文の作者は言ってます。)
Grad-CAM を火災画像で試してみる
モデルの作成:
学習データ:
火災
の画像約5000枚。
火災ではない
画像同じ約5000枚。
モデル:
ResNet34 の ImageNet の Pre-Train モデルを利用して、三番目のブロックから再学習します。最初の出力層は二項分類にします。
(学習バッチは PyTorch で実装していますが、省略させていただきます。)
評価
学習結果は以下になります
>> print(metrics.classification_report(gts, predict_labels)) (threshold = 0.5) precision recall f1-score support False 0.78 0.57 0.66 5605 True 0.66 0.84 0.74 5605
Confusion Matrix:
火災ではない画像が間違えて火災に判定されたことが結構多いです。True/False どっちも良い精度とは言えません。これから Grad-CAM でどこで間違えているのかを可視化してみましょう。
Grad-CAM を実装
PyTorch で実装となります。
class GradCAM: def __init__(self, model, feature_layer): self.model = model self.feature_layer = feature_layer self.model.eval() self.feature_grad = None self.feature_map = None self.hooks = [] # 最終層逆伝播時の勾配を記録する def save_feature_grad(module, in_grad, out_grad): self.feature_grad = out_grad[0] self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad)) # 最終層の出力 Feature Map を記録する def save_feature_map(module, inp, outp): self.feature_map = outp[0] self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map)) def forward(self, x): return self.model(x) def backward_on_target(self, output, target): self.model.zero_grad() one_hot_output = torch.zeros([1, output.size()[-1]]) one_hot_output[0][target] = 1 output.backward(gradient=one_hot_output, retain_graph=True) def clear_hook(self): for hook in self.hooks: hook.remove()
こちらは PyTorch の register_forward_hook
と register_backward_hook
メソッドで最終の CNN 層の出力(Feature Map)と逆伝播時の勾配(Gradient)を記録します。
画像を Grad-CAM に入れて可視化までの実装
まずはモデルをロードする。
image_model_path = "./fire.model" image_model_save_point = torch.load(image_model_path) image_model = models.resnet34(pretrained=False, num_classes=2) # モデルを定義 image_model.load_state_dict(image_model_save_point['state_dict']) # 保存したパラメータをモデルにロードする image_model.eval() id_to_label = { 0: 'other', 1: 'fire' }
Grad-CAM class にモデルを代入するかたちになります。
grad_cam = GradCAM(model=image_model, feature_layer=list(image_model.layer4.modules())[-1])
PyTorch の ResNet モデルの layer4 は最後のブロックで、その最後の module (最終の CNN 層)を取得して、GradCAMの feature_layer
に渡します。
画像を開いて前処理:
from PIL import Image from torchvision.transforms.functional import to_pil_image VISUALIZE_SIZE = (224, 224) # 可視化する時に使うサイズ。PyTorch ResNet の Pre-Train モデルのデフォルト入力サイズを使います normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transform = transforms.Compose([ transforms.Resize(VISUALIZE_SIZE), transforms.ToTensor(), normalize]) path = "./fire.jpg" image = Image.open(path) image.thumbnail(VISUALIZE_SIZE, Image.ANTIALIAS) display(image) # save image origin size image_orig_size = image.size # (W, H) img_tensor = image_transform(image) img_tensor = img_tensor.unsqueeze(0)
画像を Grad-CAM に入れる
model_output = grad_cam.forward(img_tensor)
target = model_output.argmax(1).item()
予測された class を取得して、逆伝播にいれる
grad_cam.backward_on_target(model_output, target)
最終層の勾配と出力を取得して、hooks をクリア
import numpy as np # Get feature gradient feature_grad = grad_cam.feature_grad.data.numpy()[0] # Get weights from gradient weights = np.mean(feature_grad, axis=(1, 2)) # Take averages for each gradient # Get features outputs feature_map = grad_cam.feature_map.data.numpy() grad_cam.clear_hook()
勾配(重み weights)と出力の特徴図(Feature Map)の加重合計で CAM を算出して、ReLU を通します
# Get cam cam = np.sum((weights * feature_map.T), axis=2).T cam = np.maximum(cam, 0) # apply ReLU to cam
CAM を可視化するために、resize して正規化
import cv2 cam = cv2.resize(cam, VISUALIZE_SIZE) cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) # Normalize between 0-1 cam = np.uint8(cam * 255) # Scale between 0-255 to visualize
元画像に CAM を合成
activation_heatmap = np.expand_dims(cam, axis=0).transpose(1,2,0) org_img = np.asarray(image.resize(VISUALIZE_SIZE)) img_with_heatmap = np.multiply(np.float32(activation_heatmap), np.float32(org_img)) img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap) org_img = cv2.resize(org_img, image_orig_size)
可視化
import matplotlib.pyplot as plt plt.figure(figsize=(20,10)) plt.subplot(1,2,1) plt.imshow(org_img) plt.subplot(1,2,2) plt.imshow(cv2.resize(np.uint8(255 * img_with_heatmap), image_orig_size)) plt.show()
実際の画像を入れてみる
Flickr にある Commercial use allowed の火災写真をモデルに入れてみます。
まずは正解例から:
この写真は 0.99 で正しく火災写真と判定されました。Grad-CAM の結果も正しく火災の場所を特定できていると思います。(消防員も特定してほしかった、、、後、この写真は火災訓練っぽいので、本当は 火災ではない
が正解かもしれません。)
もう一つ正解例:
モデルは煙をうまく見つけています。
悪い例をも見てみましょう:
Grad-CAM は正しそうに炎に特定したが、判定結果をみたら 0.7869 で other
と判定されました。つまり、学習したモデルはこの炎で写真は火災写真ではないと判定しました、、、(たしかに炎がすごすぎで、フェイクっぽいかもしれないですね)
ちなみに強制で 火災
にしてみたらどうなるでしょうか?
# target = model_output.argmax(1).item() # 予測値をコメントアウトして target を 1 に指定して逆伝播させる target = 1 grad_cam.backward_on_target(model_output, target)
(other: 0.7869, fire: 0.2131)
消防車の部分を見ていました!
※ 上記のソースコードは https://github.com/mapler/gradcam-pytorch においてあります。
まとめ
今回は Grad-CAM という手法で CNN が画像のどこを見て判定しているかを可視化してみました。Grad-CAM を利用したモデルを可視化することによって、モデルが何を学習したか、何を学習不足なのかがわかるので、実業務の中でモデルのチューニング、学習データの選別などの領域で活用できます。
FASTALERT が扱うような SNS の投稿には、一般的に、画像だけでなくテキストも含まれていますが、このような自然言語の分類タスクに関しても TextCNN などの CNN を使った先行研究が存在しています。次回は、GradCAM を使った TextCNN の可視化を紹介したいと思います。
References
- Learning Deep Features for Discriminative Localization
- Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
- ディープラーニングの判断根拠を理解する手法
- 深層学習は画像のどこを見ている!? CNNで「お好み焼き」と「ピザ」の違いを検証
- https://github.com/jacobgil/pytorch-grad-cam
- https://github.com/GunhoChoi/Grad-CAM-Pytorch
JX通信社で一緒に働いてくださる機械学習エンジニアを絶賛募集中です。
アプリエンジニア、サーバサイドエンジニアも募集しています。
まずは話を聞くだけでも構いませんので、気軽にご連絡ください