takminの書きっぱなし備忘録 @はてなブログ

主にコンピュータビジョンなど技術について、たまに自分自身のことや思いついたことなど

PyTorchで学習したモデルをOpenCVで使う

以前、Keras+Tensorflowで学習したモデルをOpenCVで推論に使用する方法について解説したことがありました

 

www.slideshare.net

OpenCVにはDNNモジュールという畳み込みニューラルネットワークを使用するための機能があります。ただこれは主に推論用で、学習のためには別のディープラーニングフレームワークで作成したモデルを別途読み込む必要があります。

OpenCVはTensorflowやCaffe等いくつかのフレームワークをサポートしているのですが、前回は初学者にも使いやすいだろうという理由でKears+Tensorflowのモデルを選択しました。なお、OpenCVはTorchはサポートしてますがPyTorchはサポートしてませんでした。

 

しかしながら、OpenCVはバージョン4.0以降のONNXのサポートにより、様々なディープラーニングフレームワークに対応できるようになったため、今回PyTorchで作成したモデルも実際に読み込めるか試してみることにしました。

尚、わざわざOpenCV上でCNNを行う動機ですが、OpenCVで開発したコードの一部にCNNを使用したいケースや、C++で推論を行うことで高速化したいときに、追加ライブラリなしに使用できるのは魅力ではないかと思います。

今回もKerasを使用した時と同様、MNISTデータセットに対してLeNetを使用して手書き数字認識を行うという例で説明します。

 

PyTorchで作成したモデルをOpenCVで使用する手順は以下の通りです。

  1. PyTorchで学習用ネットワークを構築/学習し、学習結果のパラメータを保存
  2. PyTorchで推論用ネットワークを構築し、学習したパラメータを読み込み
  3. 推論ネットワークをONNXフォーマットで保存
  4. OpenCVでONNXファイルを読み込み  

では、順をおって要点のみ解説します。

尚、今回使用したサンプルコードはすべて以下にアップしてありますので、詳しくはこちらを直接ご覧ください。

https://github.com/takmin/PyTorch2OpenCV_sample

 

1. PyTorchで学習

  PyTorchでLeNetを構築して、トレーニングを行いました。 MNISTで学習するサンプルは以下においてあります。

https://github.com/takmin/PyTorch2OpenCV_sample/blob/master/train_LeNet.py

学習コードの詳細については他に参考になるサイトは山ほどあるので、ここではいちいち説明しませんが、学習モデルについては以下のような、単純な畳み込み層とPooling層、およびDropout層で構築しました。

class LeNet5(nn.Module):
    def __init__(self, input_size):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.dropout = nn.Dropout2d(0.2)
        fc1_h = int(input_size[0] / 4 - 3)
        fc1_w = int(input_size[1] / 4 - 3)
        self.fc1 = nn.Linear(fc1_h * fc1_w * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

また学習したパラメータは以下のように保存します。

torch.save(model.state_dict(), "mnist_cnn.pt")

 

2. PyTorchで推論用モデルを構築

推論用モデルの構築と、ONNXでの保存について、詳細は以下のコードを参照してください。

https://github.com/takmin/PyTorch2OpenCV_sample/blob/master/save_LeNet_ONNX.py

推論用モデルは以下のように学習モデルから、推論時には使用しないDropoutの層を除去した形になります。

class LeNet5_Infer(nn.Module):
    def __init__(self, input_size):
        super(LeNet5_Infer, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        fc1_h = int(input_size[0] / 4 - 3)
        fc1_w = int(input_size[1] / 4 - 3)
        self.fc1 = nn.Linear(fc1_h * fc1_w * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

この推論用モデルに、学習したパラメータを読み込みます。

    model = LeNet5_Infer([28,28])
    model.load_state_dict(torch.load("mnist_cnn.pt"))

3. 推論モデルをONNXフォーマットで保存

ONNXで保存する方法の詳細については、以下の公式サイトを参照してください。

torch.onnx — PyTorch 1.7.0 documentation

ここでは以下のような形で保存しています。

    # Input to the model
    x = torch.randn(1, 1, 28, 28)
    torch_out = model(x)

    # Export the model as onnx (lenet5.onnx)
    torch.onnx.export(model,             # model being run
                        x,               # model input (or a tuple for multiple inputs)
                        "lenet5.onnx",   # where to save the model (can be a file or file-like object)
                        export_params=True,        # store the trained parameter weights inside the model file
                        opset_version=10,          # the ONNX version to export the model to
                        do_constant_folding=True,  # whether to execute constant folding for optimization
                        input_names = ['input'],   # the model's input names
                        output_names = ['output'], # the model's output names
                        dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

xという仮の入力を乱数生成しているのは、フォーマットのValidationのためらしいです。

4. OpenCVでONNXファイルを読み込み

OpenCVC++)でONNXファイルを読み込み、推論まで行うコードの詳細はこちらを参照してください。

https://github.com/takmin/PyTorch2OpenCV_sample/blob/master/opencv_LeNet.cpp

ONNXのコードは以下のようにdnn::readNetを使用すれば、自動でONNXフォーマットとして読み込んでくれます。

dnn::Net net = dnn::readNet("lenet5.onnx");

明示的にONNXフォーマットとして読み込みたいときは、

dnn::Net net = dnn::readNetFromONNX("lenet5.onnx");

としても読み込むことができます。

以上で、OpenCVからPyTorchで学習したモデルが読み込めます。


PyTorchで学習したモデルをOpenCVで使う (Custom Layer編) - takminの書きっぱなし備忘録 @はてなブログへ続く