I have a YOLO model exported to ONNX format, and I am trying to use it for inference in Java with OpenCV's DNN module. I need to get the bounding boxes and class predictions.
The model loads without errors, but `rows()` and `cols()` of the output `Mat` both return -1. I would appreciate any help with this issue.
// Load the input image. imread returns an empty Mat (it does not throw) when the file
// cannot be read, so check explicitly before running the network.
Mat img = Imgcodecs.imread("x0.png");
if (img.empty()) {
    throw new IllegalStateException("Could not read image: x0.png");
}
Net net = Dnn.readNetFromONNX("best.onnx");

// Network input size. This must equal the imgsz the model was exported with
// (224 is kept from the original snippet -- TODO confirm against the export settings).
final int inputWidth = 224;
final int inputHeight = 224;
final float confThreshold = 0.5f; // minimum final score (objectness * class score)
final float nmsThreshold = 0.45f; // IoU above which overlapping boxes are suppressed

// Scale pixels to [0, 1], resize to the network input, swap BGR -> RGB, no crop.
Mat blob = Dnn.blobFromImage(img, 1 / 255.0, new Size(inputWidth, inputHeight),
        new Scalar(0, 0, 0), true, false);
// Set the input to the network
net.setInput(blob);
List<Mat> outputs = new ArrayList<>();
net.forward(outputs, net.getUnconnectedOutLayersNames());
System.out.println("Number of output layers: " + outputs.size());

// NOTE(review): Ultralytics YOLO ONNX exports emit box coordinates in network-input
// pixels (0..inputWidth), not normalized [0, 1]. These factors map them back onto the
// original image. If your exporter emits normalized coordinates, use img.width()/height().
double xFactor = img.width() / (double) inputWidth;
double yFactor = img.height() / (double) inputHeight;

// Candidate detections collected across all outputs, filtered by NMS afterwards.
List<org.opencv.core.Rect2d> boxes = new ArrayList<>();
List<Float> scores = new ArrayList<>();
List<Integer> classIds = new ArrayList<>();

for (Mat output : outputs) {
    // The YOLO head produces a 3-D blob such as [1, numBoxes, 5 + numClasses].
    // Mat.rows()/cols() return -1 for Mats with more than two dimensions (the cause of
    // the "-1" symptom), so drop the batch axis and view it as a 2-D
    // [numBoxes x (5 + numClasses)] matrix. This assumes batch size 1.
    // NOTE(review): YOLOv8/v11 exports use [1, 4 + numClasses, numBoxes] with no
    // objectness column; for those, transpose the 2-D Mat and read classes from index 4.
    Mat detections = output.dims() > 2 ? output.reshape(1, output.size(1)) : output;

    int rows = detections.rows();
    int cols = detections.cols();
    System.out.println("Rows in output: " + rows);
    System.out.println("Columns in output: " + cols);
    if (rows <= 0) {
        System.out.println("Empty output, skipping.");
        continue; // Skip empty outputs
    }
    // Every row needs at least cx, cy, w, h and objectness. The column count is the same
    // for all rows, so check it once rather than once per row.
    if (cols < 5) {
        System.out.println("Insufficient columns in row, skipping.");
        continue;
    }

    // Network outputs are CV_32F; Mat.get(int, int, double[]) only accepts CV_64F Mats,
    // so read the whole matrix into a float[] in a single call.
    float[] data = new float[rows * cols];
    detections.get(0, 0, data);

    for (int i = 0; i < rows; i++) {
        int base = i * cols;
        float objectness = data[base + 4];
        // Cheap early reject: class scores are probabilities <= 1, so the final score
        // can never exceed the objectness score.
        if (objectness < confThreshold) {
            continue;
        }

        // Pick the best class; class scores start at column 5. A model with no class
        // columns keeps classId = -1 and is scored by objectness alone.
        int classId = -1;
        float maxClassScore = cols > 5 ? Float.NEGATIVE_INFINITY : 1f;
        for (int j = 5; j < cols; j++) {
            if (data[base + j] > maxClassScore) {
                maxClassScore = data[base + j];
                classId = j - 5;
            }
        }

        // YOLOv5-style post-processing: final confidence = objectness * class probability.
        float score = objectness * maxClassScore;
        if (score < confThreshold) {
            continue;
        }

        // Convert the centre-based box into a top-left based box in original-image pixels.
        double width = data[base + 2] * xFactor;
        double height = data[base + 3] * yFactor;
        double left = data[base] * xFactor - width / 2;
        double top = data[base + 1] * yFactor - height / 2;

        boxes.add(new org.opencv.core.Rect2d(left, top, width, height));
        scores.add(score);
        classIds.add(classId);
    }
}

// Non-maximum suppression removes the many overlapping boxes YOLO emits per object.
// This is class-agnostic; newer OpenCV releases also provide Dnn.NMSBoxesBatched for
// per-class suppression.
if (!boxes.isEmpty()) {
    org.opencv.core.MatOfRect2d boxMat =
            new org.opencv.core.MatOfRect2d(boxes.toArray(new org.opencv.core.Rect2d[0]));
    float[] scoreArray = new float[scores.size()];
    for (int k = 0; k < scoreArray.length; k++) {
        scoreArray[k] = scores.get(k);
    }
    org.opencv.core.MatOfFloat scoreMat = new org.opencv.core.MatOfFloat(scoreArray);
    org.opencv.core.MatOfInt keep = new org.opencv.core.MatOfInt();
    Dnn.NMSBoxes(boxMat, scoreMat, confThreshold, nmsThreshold, keep);

    // Print the surviving detections.
    for (int idx : keep.toArray()) {
        org.opencv.core.Rect2d b = boxes.get(idx);
        Rect boundingBox = new Rect((int) b.x, (int) b.y, (int) b.width, (int) b.height);
        System.out.println("Bounding Box: " + boundingBox);
        System.out.println("Class ID: " + classIds.get(idx));
        System.out.println("Confidence: " + scores.get(idx));
    }
}