/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Translator for YOLO-style instance segmentation models.
 *
 * <p>The model output is expected to contain two arrays: a prediction array holding box
 * coordinates (center-x, center-y, width, height), per-class scores and mask coefficients for
 * every candidate, and a mask-prototype array. Candidates are filtered by {@code threshold},
 * de-duplicated with non-maximum suppression, and each survivor is returned as a {@link Mask}
 * whose pixel mask is {@code coefficients x prototypes} binarized at zero.
 */
public class YoloSegmentationTranslator extends YoloV5Translator {
    private static final int[] AXIS_0 = new int[] {0};
    private static final int[] AXIS_1 = new int[] {1};

    // Minimum best-class score a candidate needs to be kept.
    private final float threshold;
    // IoU threshold passed to Rectangle.nms.
    private final float nmsThreshold;

    /**
     * Creates the instance segmentation translator from the given builder.
     *
     * @param builder the builder holding the translator configuration
     */
    public YoloSegmentationTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Decodes raw model output into detected objects with per-instance masks.
     *
     * <p>{@code list.get(0)} is read as a 2-D array of shape {@code (4 + numClasses +
     * numMaskCoefficients, numCandidates)}; it is transposed so each row describes one
     * candidate. {@code list.get(1)} is read as the prototype array of shape {@code
     * (numMaskCoefficients, maskH, maskW)}. Returned boxes are normalized by the inherited
     * {@code width}/{@code height} (presumably the model input size — confirm against the base
     * translator).
     *
     * @param ctx the translator context (unused)
     * @param list the model outputs: predictions followed by mask prototypes
     * @return the detections that survive thresholding and NMS; empty if none do
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.get(0);
        NDArray protos = list.get(1);
        int maskIndex = classes.size() + 4;

        // Keep only candidates whose best class score beats the confidence threshold.
        NDArray candidates = pred.get("4:" + maskIndex).max(AXIS_0).gt(threshold);
        pred = pred.transpose();
        NDArray corners = xywh2xyxy(pred.get("..., :4"));
        pred = corners.concat(pred.get("..., 4:"), -1).get(candidates);

        // Columns: [0, 4) box corners, [4, maskIndex) class scores, [maskIndex, ...) mask coeffs.
        NDList split = pred.split(new long[] {4L, maskIndex}, 1);
        NDArray box = split.get(0);
        NDArray classScores = split.get(1);
        int numBox = Math.toIntExact(box.getShape().get(0));

        float[] buf = box.toFloatArray();
        float[] confidences = classScores.max(AXIS_1).toFloatArray();
        long[] ids = classScores.argMax(1).toLongArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; ++i) {
            float xPos = buf[i * 4];
            float yPos = buf[i * 4 + 1];
            float w = buf[i * 4 + 2] - xPos;
            float h = buf[i * 4 + 3] - yPos;
            boxes.add(new Rectangle(xPos, yPos, w, h));
            scores.add((double) confidences[i]);
        }
        List<Integer> nms = Rectangle.nms(boxes, scores, nmsThreshold);
        if (nms.isEmpty()) {
            // Nothing survived; skip building NDArrays from an empty index array.
            return new DetectedObjects(
                    new ArrayList<String>(), new ArrayList<Double>(), new ArrayList<BoundingBox>());
        }

        long[] idx = nms.stream().mapToLong(Integer::longValue).toArray();
        NDArray selected = box.getManager().create(idx);

        // Per-detection masks: (count x numProtos) x (numProtos x maskH*maskW), binarized at 0.
        int numProtos = Math.toIntExact(protos.getShape().get(0));
        int maskH = Math.toIntExact(protos.getShape().get(1));
        int maskW = Math.toIntExact(protos.getShape().get(2));
        NDArray flatProtos = protos.reshape((long) numProtos, (long) maskH * maskW);
        float[] maskArray =
                split.get(2)
                        .get(selected)
                        .matMul(flatProtos)
                        .reshape(nms.size(), maskH, maskW)
                        .gt(0f)
                        .toType(DataType.FLOAT32, true)
                        .toFloatArray();

        buf = box.get(selected).toFloatArray();

        int count = nms.size();
        List<String> retClasses = new ArrayList<>(count);
        List<Double> retProbs = new ArrayList<>(count);
        List<BoundingBox> retBB = new ArrayList<>(count);
        float imageW = (float) width;
        float imageH = (float) height;
        for (int i = 0; i < count; ++i) {
            float x = buf[i * 4] / imageW;
            float y = buf[i * 4 + 1] / imageH;
            float w = buf[i * 4 + 2] / imageW - x;
            // Height must be normalized by the image height, not the width.
            float h = buf[i * 4 + 3] / imageH - y;

            int id = nms.get(i);
            retClasses.add(classes.get((int) ids[id]));
            retProbs.add((double) confidences[id]);
            retBB.add(new Mask(x, y, w, h, extractMask(maskArray, i, maskH, maskW), true));
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    /**
     * Copies the {@code index}-th mask out of a flattened {@code (count, maskH, maskW)} buffer.
     *
     * @param maskArray the flattened masks of all selected detections
     * @param index which detection's mask to extract
     * @param maskH mask height in pixels
     * @param maskW mask width in pixels
     * @return a new {@code maskH x maskW} array holding that detection's mask
     */
    private static float[][] extractMask(float[] maskArray, int index, int maskH, int maskW) {
        float[][] mask = new float[maskH][maskW];
        // Each detection occupies a contiguous maskH * maskW slab of the buffer.
        int offset = index * maskH * maskW;
        for (int row = 0; row < maskH; ++row) {
            System.arraycopy(maskArray, offset + row * maskW, mask[row], 0, maskW);
        }
        return mask;
    }

    /**
     * Converts boxes from (center-x, center-y, width, height) to (x1, y1, x2, y2) along the
     * last axis.
     *
     * @param array boxes whose last dimension holds center-x, center-y, width, height
     * @return boxes whose last dimension holds the top-left and bottom-right corners
     */
    private NDArray xywh2xyxy(NDArray array) {
        NDArray xy = array.get("..., :2");
        NDArray halfWh = array.get("..., 2:").div(2);
        return xy.sub(halfWh).concat(xy.add(halfWh), -1);
    }

    /**
     * Creates a builder to build a {@code YoloSegmentationTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * @param arguments the pre- and post-processing arguments
     * @return a new builder with the arguments applied
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** The builder for {@link YoloSegmentationTranslator}. */
    public static class Builder extends YoloV5Translator.Builder {

        Builder() {}

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link YoloSegmentationTranslator}
         */
        @Override
        public YoloSegmentationTranslator build() {
            validate();
            return new YoloSegmentationTranslator(this);
        }
    }
}

