Hi, I have the following code in Python:
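import cv2
import torch
import torch.nn as nn
# import onnxruntime  # needed only for the commented-out check below


class LayerLSTM(nn.Module):
    def __init__(self):
        super(LayerLSTM, self).__init__()
        # input_size=156, hidden_size=512, num_layers=1
        self.rnns = nn.LSTM(156, 512, 1, batch_first=False)

    def forward(self, x, hx, cx):
        # The initial hidden and cell states are passed in as model inputs
        x, (hx, cx) = self.rnns(x, (hx, cx))
        return x


model = LayerLSTM()
with torch.no_grad():
    x = torch.randn(1, 1, 156)   # (seq_len, batch, input_size)
    hx = torch.randn(1, 1, 512)  # (num_layers, batch, hidden_size)
    cx = torch.randn(1, 1, 512)  # (num_layers, batch, hidden_size)
    out = model(x, hx, cx)

torch.onnx.export(model, (x, hx, cx), '/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx', verbose=True, input_names=['x', 'hx', 'cx'], output_names=['output'])

# Check of the exported model with ONNX Runtime (commented out):
# sess = onnxruntime.InferenceSession('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')
# out_on = sess.run(None, {'x': x.cpu().numpy(), 'hx': hx.numpy(), 'cx': cx.numpy()})
# print(out.numpy() - out_on[0])

# Importing the same file into OpenCV is the call that fails
net = cv2.dnn.readNetFromONNX('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')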
As the commented-out lines show, I tested the exported model with ONNX Runtime, and it works as well as the PyTorch one.
When I try to import the ONNX model into OpenCV, I get the following error:
[ERROR:0] global /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp (718) handleNode DNN/ONNX: ERROR during processing node with 7 inputs and 3 outputs: [LSTM]:(28)
Traceback (most recent call last):
File "/home/samuels/DVS_Original/deep-stabilization_original/dvs/sample_lstm.py", line 29, in <module>
net = cv2.dnn.readNetFromONNX('/home/samuels/DVS_Original/deep-stabilization/dvs/onnx_model/sample_lstm.onnx')
cv2.error: OpenCV(4.5.4) /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp:739: error: (-2:Unspecified error) in function 'handleNode'
> Node [LSTM]:(28) parse error: OpenCV(4.5.4) /tmp/pip-req-build-w88qv8vs/opencv/modules/dnn/src/onnx/onnx_importer.cpp:463: error: (-5:Bad argument) Blob hx not found in const blobs in function 'getBlob'
Is this a bug? What am I doing wrong when passing the hx input? Do you have any suggestions?
I am using OpenCV 4.5.4.60 and PyTorch 1.11.