Import LSTM from ONNX to OpenCV

Hi, I have the following code in Python:

```python
import cv2
import torch
import torch.nn as nn
# import onnxruntime  # needed only for the commented-out ONNX Runtime check

class LayerLSTM(nn.Module):
    def __init__(self):
        super(LayerLSTM, self).__init__()
        self.rnns = nn.LSTM(156, 512, 1, batch_first=False)

    def forward(self, x, hx, cx):
        x, (hx, cx) = self.rnns(x, (hx, cx))
        return x

model = LayerLSTM()

with torch.no_grad():
    x = torch.randn(1, 1, 156)
    hx = torch.randn(1, 1, 512)
    cx = torch.randn(1, 1, 512)
    out = model(x, hx, cx)
torch.onnx.export(model, (x, hx, cx), '/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx', verbose=True, input_names=['x', 'hx', 'cx'], output_names=['output'])
# sess = onnxruntime.InferenceSession('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')
# out_on = sess.run(None, {'x': x.cpu().numpy(), 'hx': hx.numpy(), 'cx': cx.numpy()})
# print(out.numpy() - out_on[0])

net = cv2.dnn.readNetFromONNX('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')
```

As you can see in the commented-out lines, I tested the exported model with ONNX Runtime, and its output matches that of the PyTorch model.
When I try to import the ONNX model in OpenCV, however, I get the following error:

```
[ERROR:0] global /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp (718) handleNode DNN/ONNX: ERROR during processing node with 7 inputs and 3 outputs: [LSTM]:(28)
Traceback (most recent call last):
  File "/home/samuels/DVS_Original/deep-stabilization_original/dvs/sample_lstm.py", line 29, in <module>
    net = cv2.dnn.readNetFromONNX('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')
cv2.error: OpenCV(4.5.4) /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp:739: error: (-2:Unspecified error) in function 'handleNode'
> Node [LSTM]:(28) parse error: OpenCV(4.5.4) /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp:463: error: (-5:Bad argument) Blob hx not found in const blobs in function 'getBlob'
```

Is it a bug? What am I doing wrong in passing the hx input? Do you guys have any suggestions?

I am using OpenCV 4.5.4.60 and PyTorch 1.11.

Potentially a bug, or simply not implemented yet.

You should check the issues on OpenCV’s GitHub. If it hasn’t been reported yet, create an issue that includes steps to reproduce (code and data).
