CNNを使った分類問題の判断根拠(画像編)

この記事は JX通信社Advent Calendar の 12 日目です。

FASTALERT チーム機械学習エンジニアの mapler です。FASTALERT の機械学習とサーバーサイドの開発をしています。

FASTALERT(ファストアラート)は、SNSから事件・事故・災害等の緊急情報を検知し、配信する緊急情報配信サービスです。その処理の中でも画像認識は欠かせない存在です。

今回お話しするのは CNN (Convolutional Neural Network,または畳み込みニューラルネットワーク) というニューラルネットのモデルです。CNN は行列の空間情報を捉えるため、特に画像認識分野では非常に有効な手法です。

convolution
from Performing Convolution Operations

CNN の解釈性

ニューラルネットワークはとても有効な一方、その根拠が解釈しにくいとよく言われています。FASTALERT でも、ニュースの価値を判別するアルゴリズムを改善するためにモデルの判定結果を解釈することは重要です。

CNN もしくは深層学習の解釈性について、icoxfog417さんディープラーニングの判断根拠を理解する手法 でたくさんの研究や手法が紹介されています。

今回は Grad-CAM という判定根拠の可視化方法について実験してみようと思います。

Grad-CAM の仕組み

p
from Introduction to Deep Learning: What Are Convolutional Neural Networks?

上の図に、入力画像の受容野(Receptive field)がいくつかの Convolution + ReLU + Pooling 層を通って、一次元ベクトルに Flatten される直前まで、位置が変わらないことがわかります。(左上の部分が複数の Convlution 層を通った後の出力でも左上にあります)

A number of previous works have asserted that deeper representations in a CNN capture higher-level visual constructs [5, 35]. Furthermore, convolutional features naturally retain spatial information which is lost in fully-connected layers, so we can expect the last convolutional layers to have the best compromise between high-level semantics and detailed spatial information.

from: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

(拙訳)たくさんの研究によって、より深い層ではより高度、豊富な特徴が捉えられると示されています。しかし、Flatten によって、空間情報は分類の fully-connected 層で失われます。最後の CNN 層は分類特徴と空間情報を両方持つ、可視化に最も利用できる層になります。

Grad-CAM はこの最後の CNN 層の勾配を利用して、どのニューラルのどの部分が出力のどの分類に一番貢献したかを計算します。

CAM
CAM from Learning Deep Features for Discriminative Localization

上の図は CAM という Grad-CAM が登場する前の CNN 根拠可視化手法です。

CAM は Grad-CAM と異なり、勾配を利用するのではなく、CNN 層の後の Fully-Connected 層と一つの GAP(Global Average Pooling)に入れ替えています。この GAP は(豊富な特徴情報を持っている)最後の CNN の出力の特徴図(Feature Map)を Pooling して、分類のクラスとマッピングします(Class Activation Mapping)。

上の図で犬(Australian terrier)を示す w_2(赤い四角)と w_n(緑の四角)の特徴図の重みは w_1(青い四角)より強いのがわかります。(逆に人間を判定する場合、w_1 の重みは強くなるでしょう。)

f:id:maplerme:20181212104555p:plain
from: Learning Deep Features for Discriminative Localization
こうやって重み w を付けて特徴図の加重合計(Weighted Sum)の結果、できた図は Class Activation Map(CAM)となります。 一方、Grad-CAM は GAP 層の入れ替え不要で、逆伝播の時の勾配を利用して、特徴図の重みを実現しています。
f:id:maplerme:20181212104717p:plain
from: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
上の式の \alpha^{c}_k は CNN の出力特徴図(Feature Map)A^{k} がクラス c に判定される確率の y^{c} に対する偏微分、もしくは勾配(gradient)となります。

ここの勾配はニューラルネットワークの逆伝播で計算され、特徴図 A^{k}_{ij} の中の i、j 位置のピクセルの変化に対し、クラス c に判定される確率の影響を表しています。そして、この勾配はちょうど CAM の重みと同じになることを論文の中でも証明しています。(Grad-CAM は汎用化した CAM だと論文の作者は言ってます。)

f:id:maplerme:20181212104824p:plain
from: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
CAM と同様に、重みの α と特徴図の A を加重合計して、重みの代わりに勾配(Gradient)を使った CAM (すなわち Grad-CAM)ができます。ちなみに、クラス判定にマイナス影響をする部分を"非表示"にするため、活性化関数 ReLU をかけています。

grad-cam
Grad-CAM Flow from http://gradcam.cloudcv.org/

Grad-CAM を火災画像で試してみる

モデルの作成:

学習データ:

火災 の画像約5000枚。
火災ではない 画像同じ約5000枚。

モデル:

f:id:maplerme:20181212105046p:plain
ResNet34 from https://arxiv.org/abs/1512.03385 (rotated)

ResNet34 の ImageNet の Pre-Train モデルを利用して、三番目のブロックから再学習します。最初の出力層は二項分類にします。

(学習バッチは PyTorch で実装していますが、省略させていただきます。)

評価

学習結果は以下になります

>> print(metrics.classification_report(gts, predict_labels))
(threshold = 0.5)
              precision    recall  f1-score   support

       False       0.78      0.57      0.66      5605
        True       0.66      0.84      0.74      5605

Confusion Matrix: f:id:maplerme:20181212105332p:plain

火災ではない画像が間違えて火災に判定されたことが結構多いです。True/False どっちも良い精度とは言えません。これから Grad-CAM でどこで間違えているのかを可視化してみましょう。

Grad-CAM を実装

PyTorch で実装となります。

class GradCAM:
    def __init__(self, model, feature_layer):
        self.model = model
        self.feature_layer = feature_layer
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        self.hooks = []

        # 最終層逆伝播時の勾配を記録する
        def save_feature_grad(module, in_grad, out_grad):
            self.feature_grad = out_grad[0]
        self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad))

        # 最終層の出力 Feature Map を記録する
        def save_feature_map(module, inp, outp):
            self.feature_map = outp[0]
        self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map))

    def forward(self, x):
        return self.model(x)

    def backward_on_target(self, output, target):
        self.model.zero_grad()
        one_hot_output = torch.zeros([1, output.size()[-1]])
        one_hot_output[0][target] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def clear_hook(self):
        for hook in self.hooks:
            hook.remove()

こちらは PyTorch の register_forward_hookregister_backward_hook メソッドで最終の CNN 層の出力(Feature Map)と逆伝播時の勾配(Gradient)を記録します。

画像を Grad-CAM に入れて可視化までの実装

まずはモデルをロードする。

image_model_path = "./fire.model"
image_model_save_point = torch.load(image_model_path)
image_model = models.resnet34(pretrained=False, num_classes=2)  # モデルを定義
image_model.load_state_dict(image_model_save_point['state_dict'])  # 保存したパラメータをモデルにロードする
image_model.eval()
id_to_label = {
    0: 'other',
    1: 'fire'
}

Grad-CAM class にモデルを代入するかたちになります。

grad_cam = GradCAM(model=image_model, feature_layer=list(image_model.layer4.modules())[-1])

PyTorch の ResNet モデルの layer4 は最後のブロックで、その最後の module (最終の CNN 層)を取得して、GradCAMの feature_layer に渡します。

画像を開いて前処理:

from PIL import Image
from torchvision.transforms.functional import to_pil_image

VISUALIZE_SIZE = (224, 224)  # 可視化する時に使うサイズ。PyTorch ResNet の Pre-Train モデルのデフォルト入力サイズを使います

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

image_transform = transforms.Compose([
        transforms.Resize(VISUALIZE_SIZE),
        transforms.ToTensor(),
        normalize])

path = "./fire.jpg"
image = Image.open(path)
image.thumbnail(VISUALIZE_SIZE, Image.ANTIALIAS)
display(image)

# save image origin size
image_orig_size = image.size # (W, H)

img_tensor = image_transform(image)
img_tensor = img_tensor.unsqueeze(0)

画像を Grad-CAM に入れる

model_output = grad_cam.forward(img_tensor)
target = model_output.argmax(1).item()

予測された class を取得して、逆伝播にいれる

grad_cam.backward_on_target(model_output, target)

最終層の勾配と出力を取得して、hooks をクリア

import numpy as np
# Get feature gradient
feature_grad = grad_cam.feature_grad.data.numpy()[0]
# Get weights from gradient
weights = np.mean(feature_grad, axis=(1, 2))  # Take averages for each gradient
# Get features outputs
feature_map = grad_cam.feature_map.data.numpy()
grad_cam.clear_hook()

勾配(重み weights)と出力の特徴図(Feature Map)の加重合計で CAM を算出して、ReLU を通します

# Get cam
cam = np.sum((weights * feature_map.T), axis=2).T
cam = np.maximum(cam, 0)  # apply ReLU to cam

CAM を可視化するために、resize して正規化

import cv2
cam = cv2.resize(cam, VISUALIZE_SIZE)
cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam))  # Normalize between 0-1
cam = np.uint8(cam * 255)  # Scale between 0-255 to visualize

元画像に CAM を合成

activation_heatmap = np.expand_dims(cam, axis=0).transpose(1,2,0)
org_img = np.asarray(image.resize(VISUALIZE_SIZE))
img_with_heatmap = np.multiply(np.float32(activation_heatmap), np.float32(org_img))
img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
org_img = cv2.resize(org_img, image_orig_size)

可視化

import matplotlib.pyplot as plt
plt.figure(figsize=(20,10))
plt.subplot(1,2,1)
plt.imshow(org_img)
plt.subplot(1,2,2)
plt.imshow(cv2.resize(np.uint8(255 * img_with_heatmap), image_orig_size))
plt.show()

実際の画像を入れてみる

Flickr にある Commercial use allowed の火災写真をモデルに入れてみます。

まずは正解例から:

f:id:maplerme:20181212105805p:plain
Image Source: https://flic.kr/p/Pf9dW3
(other: 0.0100, fire: 0.9900)

この写真は 0.99 で正しく火災写真と判定されました。Grad-CAM の結果も正しく火災の場所を特定できていると思います。(消防員も特定してほしかった、、、後、この写真は火災訓練っぽいので、本当は 火災ではない が正解かもしれません。)

もう一つ正解例:

f:id:maplerme:20181212105840p:plain
Image Source: https://flic.kr/p/M9wgsU
(other: 0.2345, fire: 0.7655)

モデルは煙をうまく見つけています。

悪い例をも見てみましょう:

f:id:maplerme:20181212105855p:plain
Image Source: https://flic.kr/p/29d3vTz
(other: 0.7869, fire: 0.2131)

Grad-CAM は正しそうに炎に特定したが、判定結果をみたら 0.7869 で other と判定されました。つまり、学習したモデルはこの炎で写真は火災写真ではないと判定しました、、、(たしかに炎がすごすぎで、フェイクっぽいかもしれないですね)

ちなみに強制で 火災 にしてみたらどうなるでしょうか?

# target = model_output.argmax(1).item()  # 予測値をコメントアウトして target を 1 に指定して逆伝播させる
target = 1
grad_cam.backward_on_target(model_output, target)

f:id:maplerme:20181212105936p:plain (other: 0.7869, fire: 0.2131)

消防車の部分を見ていました!

※ 上記のソースコードは https://github.com/mapler/gradcam-pytorch においてあります。

まとめ

今回は Grad-CAM という手法で CNN が画像のどこを見て判定しているかを可視化してみました。Grad-CAM を利用したモデルを可視化することによって、モデルが何を学習したか、何を学習不足なのかがわかるので、実業務の中でモデルのチューニング、学習データの選別などの領域で活用できます。

FASTALERT が扱うような SNS の投稿には、一般的に、画像だけでなくテキストも含まれていますが、このような自然言語の分類タスクに関しても TextCNN などの CNN を使った先行研究が存在しています。次回は、GradCAM を使った TextCNN の可視化を紹介したいと思います。

References


JX通信社で一緒に働いてくださる機械学習エンジニアを絶賛募集中です。
アプリエンジニア、サーバサイドエンジニアも募集しています。
まずは話を聞くだけでも構いませんので、気軽にご連絡ください

www.wantedly.com