import {
    Tensor,
    Tensor2D,
    Tensor4D,
    div,
    maximum,
    range,
    stack,
    sub,
    tensor,
    tidy,
    image as tfImage,
} from '@tensorflow/tfjs';
import { FabricObject, Point, Polygon } from 'fabric';
import { generateId } from '@autoixpert/lib/generate-id';
import { DEFAULT_FABRIC_CONTROL_OPTIONS } from 'src/app/reports/details/photos/photo-editor/fabric-js-custom-types';
import { determineFabricObjectBlurFactor } from '../../libraries/fabric/ts/object-blur-factor.utils';
import { convexHull } from './convex-hull.utils';
import {
    DetectionBoxesModelOutput,
    DetectionClassesModelOutput,
    DetectionMaskModelOutput,
    DetectionScoresModelOutput,
    ImageInfoModelOutput,
    LicensePlateRedactionImageMaskInfo,
    LicensePlateRedactionImageMasks,
    LicensePlateRedactionModelOutput,
    LicensePlateRedactionModelOutputUnion,
    LicensePlateRedactionPolygon,
    NumDetectionsModelOutput,
    SerializedTensor,
} from './license-plate-redaction-model.interfaces';

/** Number of detection slots in every per-detection model output (boxes, classes, scores and masks). */
export const MODEL_NUM_BOXES = 100;
/**
 * Side length of the square image space the model works in. Box coordinates are divided by it to normalize
 * them, reframed masks are resized to MODEL_IMAGE_SIZE x MODEL_IMAGE_SIZE, and polygon points are scaled from
 * it to the target dimensions.
 */
export const MODEL_IMAGE_SIZE = 512;

/**
 * Tells whether `tensor` has exactly the requested dtype and shape (same rank and the same size on every axis).
 */
function isTensor({
    tensor,
    shape,
    dtype,
}: {
    tensor: LicensePlateRedactionModelOutputUnion;
    shape: number[];
    dtype: 'float32' | 'int32';
}): boolean {
    const actualShape = tensor.shape;
    return (
        tensor.dtype === dtype &&
        actualShape.length === shape.length &&
        actualShape.every((size, axis) => size === shape[axis])
    );
}

/**
 * Returns the first model output in `predictions` whose shape and dtype match `criteria`.
 *
 * The type parameter `T` only narrows the static type of the match; no runtime check beyond shape and dtype
 * is performed.
 *
 * @returns The matching tensor, or `undefined` when no output matches (callers must handle that case).
 */
function findTensor<T>({
    predictions,
    criteria,
}: {
    predictions: LicensePlateRedactionModelOutputUnion[];
    criteria: {
        shape: number[];
        dtype: 'float32' | 'int32';
    };
}): T | undefined {
    const match = predictions.find((tensor) => isTensor({ tensor, shape: criteria.shape, dtype: criteria.dtype }));
    return match as T | undefined;
}

/**
 * Debugging helper: logs every output tensor of the license plate redaction model next to its values.
 *
 * The values are read with the synchronous `dataSync()`, so this is meant for debugging only.
 */
export function logLicensePlateRedactionModelOutput({ output }: { output: LicensePlateRedactionModelOutput }) {
    const { imageInfo, numDetections, detectionBoxes, detectionClasses, detectionScores, detectionMasks } = output;

    const report = {
        imageInfo,
        imageInfoData: imageInfo.dataSync(),
        numDetections,
        numDetectionsData: numDetections.dataSync(),
        detectionBoxes,
        detectionBoxesData: detectionBoxes.dataSync(),
        detectionClasses,
        detectionClassesData: detectionClasses.dataSync(),
        detectionScores,
        detectionScoresData: detectionScores.dataSync(),
        detectionMasks,
        detectionMasksData: detectionMasks.dataSync(),
    };

    console.log(report);
}

/**
 * Picks the individual outputs of the license plate redaction model out of the raw prediction list.
 *
 * Each output is identified purely by its shape and dtype. Presumably the order of the raw outputs is not
 * reliable, hence the lookup. `detectionClasses` and `detectionScores` share the shape [1, MODEL_NUM_BOXES]
 * and are told apart by dtype (int32 vs. float32).
 *
 * @returns The named outputs, or `undefined` (after logging an error) if any output could not be found.
 */
export function parseLicensePlateRedactionModelOutput(
    predictions: LicensePlateRedactionModelOutputUnion[],
): LicensePlateRedactionModelOutput | undefined {
    const imageInfo = findTensor<ImageInfoModelOutput>({
        predictions,
        criteria: { shape: [1, 4, 2], dtype: 'float32' },
    });

    const numDetections = findTensor<NumDetectionsModelOutput>({
        predictions,
        criteria: { shape: [1], dtype: 'int32' },
    });

    const detectionBoxes = findTensor<DetectionBoxesModelOutput>({
        predictions,
        criteria: { shape: [1, MODEL_NUM_BOXES, 4], dtype: 'float32' },
    });

    const detectionClasses = findTensor<DetectionClassesModelOutput>({
        predictions,
        criteria: { shape: [1, MODEL_NUM_BOXES], dtype: 'int32' },
    });

    const detectionScores = findTensor<DetectionScoresModelOutput>({
        predictions,
        criteria: { shape: [1, MODEL_NUM_BOXES], dtype: 'float32' },
    });

    const detectionMasks = findTensor<DetectionMaskModelOutput>({
        predictions,
        criteria: { shape: [1, MODEL_NUM_BOXES, 28, 28], dtype: 'float32' },
    });

    if (!imageInfo || !numDetections || !detectionBoxes || !detectionClasses || !detectionScores || !detectionMasks) {
        console.error('Cannot parse license plate redaction model output!');
        return undefined;
    }

    return {
        imageInfo,
        numDetections,
        detectionBoxes,
        detectionClasses,
        detectionScores,
        detectionMasks,
    };
}

/**
 * Reframes the image corners ([0, 0, 1, 1]) so that they are expressed relative to each box.
 *
 * The local coordinate frame of each box is assumed to be spanned by its own four corners: the box's
 * (ymin, xmin) maps to 0 and its (ymax, xmax) maps to 1. For every box this returns where the image corners
 * land in that frame: ((0 - ymin) / height, (0 - xmin) / width, (1 - ymin) / height, (1 - xmin) / width).
 * Heights and widths are clamped to at least 1e-4 so that degenerate boxes do not divide by zero.
 *
 * All intermediate tensors are released by `tidy`; only the returned tensor survives.
 *
 * @param detectionBoxes Float tensor of shape [numBoxes, 4] with (ymin, xmin, ymax, xmax) per box,
 *   presumably in normalized [0, 1] image coordinates — confirm with the caller.
 * @returns Tensor of shape [numBoxes, 4, 1] (four [numBoxes, 1] columns stacked on axis 1). Squeeze axis 2
 *   to obtain [numBoxes, 4].
 */
function reframeImageCornersRelativeToBoxes({ detectionBoxes }: { detectionBoxes: Tensor }): Tensor {
    return tidy(() => {
        const ymin = detectionBoxes.slice([0, 0], [-1, 1]);
        const xmin = detectionBoxes.slice([0, 1], [-1, 1]);
        const ymax = detectionBoxes.slice([0, 2], [-1, 1]);
        const xmax = detectionBoxes.slice([0, 3], [-1, 1]);

        const height = maximum(sub(ymax, ymin), 1e-4);
        const width = maximum(sub(xmax, xmin), 1e-4);

        const yminOut = div(sub(0, ymin), height);
        const xminOut = div(sub(0, xmin), width);
        const ymaxOut = div(sub(1, ymin), height);
        const xmaxOut = div(sub(1, xmin), width);

        return stack([yminOut, xminOut, ymaxOut, xmaxOut], 1);
    });
}
/**
 * Projects the per-box masks predicted by the model onto full-size image masks.
 *
 * Box coordinates are divided by MODEL_IMAGE_SIZE and reframed so that `cropAndResize` can "uncrop" each
 * 28x28 box mask (per the parser's shape check) into a MODEL_IMAGE_SIZE x MODEL_IMAGE_SIZE mask. Mask i is
 * paired with box i, and areas outside the box are filled with 0.
 *
 * All intermediate tensors are released by `tidy`; only the returned masks survive. `cropAndResize` already
 * yields float32, so no cast is applied.
 *
 * @returns `imageMasks` of shape [MODEL_NUM_BOXES, MODEL_IMAGE_SIZE, MODEL_IMAGE_SIZE, 1].
 */
export function reframeBoxMasksToImageMasks({ output }: { output: LicensePlateRedactionModelOutput }): {
    imageMasks: LicensePlateRedactionImageMasks;
} {
    const { detectionBoxes, detectionMasks } = output;

    const imageMasks = tidy(() => {
        // Normalize box coordinates and drop the batch dimension: [1, N, 4] -> [N, 4].
        const detectionBoxesNormalized = div(detectionBoxes, MODEL_IMAGE_SIZE).squeeze([0]);

        const detectionBoxesNormalizedReframed = reframeImageCornersRelativeToBoxes({
            detectionBoxes: detectionBoxesNormalized,
        }).squeeze<Tensor2D>([2]);

        // [1, N, h, w] -> [N, h, w, 1] so every box mask becomes a single-channel "image".
        const detectionMaskImage: Tensor4D = detectionMasks.squeeze([0]).expandDims(3);

        return tfImage.cropAndResize(
            detectionMaskImage,
            detectionBoxesNormalizedReframed,
            range(0, MODEL_NUM_BOXES, 1, 'int32'),
            [MODEL_IMAGE_SIZE, MODEL_IMAGE_SIZE],
            'bilinear',
            0,
        );
    });

    return { imageMasks };
}

/**
 * Selects the image masks of detections classified as license plates (class 1) whose score passes a threshold.
 *
 * Only the first `numDetections` detections are considered. If none of them scores strictly above
 * `boxScoreTreshold` and `findAtLeastOneLicensePlateAboveScoreThreshold` is set, the single highest-scoring
 * license plate detection is returned instead, provided its score is at least that fallback threshold.
 *
 * Mask tensors are created only for the selected detections. Presumably the caller owns and disposes them.
 *
 * @returns The selected masks (each of shape [MODEL_IMAGE_SIZE, MODEL_IMAGE_SIZE, 1]) with their scores,
 *   in detection order.
 */
export async function filterLicensePlateRedactionImageMasks({
    imageMasks,
    output,
    boxScoreTreshold,
    findAtLeastOneLicensePlateAboveScoreThreshold,
}: {
    imageMasks: LicensePlateRedactionImageMasks;
    output: LicensePlateRedactionModelOutput;

    /** Threshold for the detection score of the box */
    boxScoreTreshold: number;
    /**
     * Optional fallback threshold. If no detection passes `boxScoreTreshold`, the best license plate detection
     * is returned when its score is >= this value. A value of `0` (or `undefined`) disables the fallback.
     */
    findAtLeastOneLicensePlateAboveScoreThreshold?: number;
}): Promise<{ selectedImageMasks: LicensePlateRedactionImageMaskInfo[] }> {
    const { numDetections, detectionScores, detectionClasses } = output;

    // Download all per-detection values at once (asynchronously) instead of one blocking read per detection.
    const [numDetectionsData, detectionScoresData, detectionClassesData] = await Promise.all([
        numDetections.array(),
        detectionScores.array(),
        detectionClasses.array(),
    ]);

    // Collect the license plate detections (class 1) among the valid detections.
    const licensePlateDetections: { index: number; score: number }[] = [];
    for (let index = 0; index < numDetectionsData[0]; index++) {
        if (detectionClassesData[0][index] !== 1) continue;
        licensePlateDetections.push({ index, score: detectionScoresData[0][index] });
    }

    // Check if the detection score is above the threshold
    const selectedDetections = licensePlateDetections.filter(({ score }) => score > boxScoreTreshold);

    // If we want to find at least one license plate and none passed, fall back to the highest-scoring one
    // (the first one on ties). NOTE(review): the truthiness check means a fallback threshold of 0 disables
    // this branch — confirm that is intended.
    if (
        findAtLeastOneLicensePlateAboveScoreThreshold &&
        selectedDetections.length === 0 &&
        licensePlateDetections.length > 0
    ) {
        const bestDetection = licensePlateDetections.reduce((best, detection) =>
            detection.score > best.score ? detection : best,
        );
        if (bestDetection.score >= findAtLeastOneLicensePlateAboveScoreThreshold) {
            selectedDetections.push(bestDetection);
        }
    }

    // Image masks have shape [MODEL_NUM_BOXES, MODEL_IMAGE_SIZE, MODEL_IMAGE_SIZE, 1]; cut out one mask per
    // selected detection.
    const selectedImageMasks = selectedDetections.map(({ index, score }): LicensePlateRedactionImageMaskInfo => {
        const maskSlice = imageMasks.slice([index, 0, 0, 0], [1, MODEL_IMAGE_SIZE, MODEL_IMAGE_SIZE, 1]);
        const imageMaskInfo: LicensePlateRedactionImageMaskInfo = { mask: maskSlice.squeeze([0]), score };
        // squeeze() is a reshape that holds its own reference to the data, so the intermediate slice can be
        // released.
        maskSlice.dispose();
        return imageMaskInfo;
    });

    return { selectedImageMasks };
}

/**
 * Turns a single image mask into a polygon: the convex hull of all pixels whose mask value exceeds
 * `maskValueThreshold`. The polygon inherits the mask's detection score.
 *
 * The mask is read as rows of pixels with one channel ([height, width, 1]; [512, 512, 1] for this model).
 * Hull points are pixel coordinates with x = column and y = row.
 */
export function getLicensePlateRedactionImageMaskPolygon({
    imageMaskInfo,
    maskValueThreshold,
    colinearNeighborDetectionTolerance,
}: {
    imageMaskInfo: LicensePlateRedactionImageMaskInfo;
    maskValueThreshold: number;
    /** Tolerance for the convex hull (graham scan algorithm) = ignore distance between points */
    colinearNeighborDetectionTolerance: number;
}): { polygon: LicensePlateRedactionPolygon } {
    const maskRows: number[][][] = imageMaskInfo.mask.arraySync();

    // Gather [x, y] for every pixel that belongs to the mask.
    const foregroundPixels: Parameters<typeof convexHull>[0] = [];
    for (let y = 0; y < maskRows.length; y++) {
        const row = maskRows[y];
        for (let x = 0; x < row.length; x++) {
            if (row[x][0] > maskValueThreshold) {
                foregroundPixels.push([x, y]);
            }
        }
    }

    const hull = convexHull(foregroundPixels, colinearNeighborDetectionTolerance);

    const polygon: LicensePlateRedactionPolygon = {
        score: imageMaskInfo.score,
        points: hull.map(([x, y]) => ({ x, y })),
    };
    return { polygon };
}
/**
 * Merges overlapping polygons into the convex hull of their combined points.
 *
 * Merging repeats until no two resulting polygons overlap. Chains of overlaps (A overlaps B, B overlaps C)
 * therefore collapse into one polygon regardless of input order. A merged polygon keeps the other fields
 * (e.g. `score`) of the earlier polygon it was merged into.
 *
 * The input polygons are never mutated; shallow copies are returned.
 */
export function mergeOverlappingLicensePlateRedactionPolygons({
    polygons,
}: {
    polygons: LicensePlateRedactionPolygon[];
}): LicensePlateRedactionPolygon[] {
    // Shallow copies: merging only reassigns `.points`, so the caller's objects stay untouched.
    let mergedPolygons: LicensePlateRedactionPolygon[] = polygons.map((polygon) => ({ ...polygon }));

    let mergedInLastPass = true;
    while (mergedInLastPass) {
        mergedInLastPass = false;
        const nextPolygons: LicensePlateRedactionPolygon[] = [];

        for (const polygon of mergedPolygons) {
            const overlappingPolygon = nextPolygons.find((candidate) => doPolygonsOverlap(polygon, candidate));
            if (!overlappingPolygon) {
                nextPolygons.push(polygon);
                continue;
            }

            // Grow the kept polygon to the convex hull of both outlines.
            overlappingPolygon.points = convexHull(
                [...polygon.points, ...overlappingPolygon.points].map(({ x, y }) => [x, y]),
            ).map(([x, y]) => ({ x, y }));
            mergedInLastPass = true;
        }

        // Each merging pass strictly reduces the polygon count, so this loop terminates.
        mergedPolygons = nextPolygons;
    }

    return mergedPolygons;
}

/**
 * Returns whether two polygons overlap.
 *
 * The function checks three things and returns true as soon as one holds:
 * - a vertex of `polygon1` lies inside `polygon2`;
 * - a vertex of `polygon2` lies inside `polygon1`;
 * - an edge of one polygon properly crosses an edge of the other.
 *
 * The edge test catches overlaps where no vertex is contained, e.g. two thin rectangles crossing like a plus
 * sign. Edges that merely touch or are collinear do not count as crossing.
 */
function doPolygonsOverlap(polygon1: LicensePlateRedactionPolygon, polygon2: LicensePlateRedactionPolygon): boolean {
    if (polygon1.points.some((point) => isPointInPolygon(point, polygon2))) return true;
    if (polygon2.points.some((point) => isPointInPolygon(point, polygon1))) return true;
    return doPolygonEdgesCross(polygon1, polygon2);
}

/** Signed area of the parallelogram (origin→a, origin→b); its sign tells on which side of origin→a `b` lies. */
function crossProduct(
    origin: { x: number; y: number },
    a: { x: number; y: number },
    b: { x: number; y: number },
): number {
    return (a.x - origin.x) * (b.y - origin.y) - (a.y - origin.y) * (b.x - origin.x);
}

/** Returns whether segments a1-a2 and b1-b2 cross at a single interior point of both segments. */
function doSegmentsCross(
    a1: { x: number; y: number },
    a2: { x: number; y: number },
    b1: { x: number; y: number },
    b2: { x: number; y: number },
): boolean {
    return (
        crossProduct(a1, a2, b1) * crossProduct(a1, a2, b2) < 0 &&
        crossProduct(b1, b2, a1) * crossProduct(b1, b2, a2) < 0
    );
}

/** Returns whether any edge of `polygon1` properly crosses any edge of `polygon2` (both treated as closed). */
function doPolygonEdgesCross(polygon1: LicensePlateRedactionPolygon, polygon2: LicensePlateRedactionPolygon): boolean {
    const points1 = polygon1.points;
    const points2 = polygon2.points;

    for (let i = 0; i < points1.length; i++) {
        const a1 = points1[i];
        const a2 = points1[(i + 1) % points1.length];
        for (let j = 0; j < points2.length; j++) {
            const b1 = points2[j];
            const b2 = points2[(j + 1) % points2.length];
            if (doSegmentsCross(a1, a2, b1, b2)) return true;
        }
    }

    return false;
}

/**
 * Even-odd (ray casting) point-in-polygon test.
 *
 * A horizontal ray is cast from `point` towards +x. Each polygon edge that straddles the ray's y and crosses
 * it to the right of the point toggles the result. Empty polygons contain no points.
 */
function isPointInPolygon(point: { x: number; y: number }, polygon: LicensePlateRedactionPolygon): boolean {
    const { points } = polygon;
    let inside = false;

    // Walk the closed outline as (previous -> current) edges, starting with the closing edge (last -> first).
    let previous = points[points.length - 1];
    for (const current of points) {
        const straddlesRay = current.y > point.y !== previous.y > point.y;
        if (straddlesRay) {
            const crossingX =
                ((previous.x - current.x) * (point.y - current.y)) / (previous.y - current.y) + current.x;
            if (point.x < crossingX) {
                inside = !inside;
            }
        }
        previous = current;
    }

    return inside;
}

/**
 * Maps a polygon from model image space (MODEL_IMAGE_SIZE x MODEL_IMAGE_SIZE) to an image of `width` x `height`.
 * All other polygon fields (e.g. `score`) are copied unchanged; a new object is returned.
 */
export function scaleLicensePlateRedactionPolygonToDimensions({
    polygon,
    width,
    height,
}: {
    polygon: LicensePlateRedactionPolygon;
    width: number;
    height: number;
}): LicensePlateRedactionPolygon {
    // Project one coordinate from [0, MODEL_IMAGE_SIZE] onto a target axis of length `size`.
    const fromModelSpace = (value: number, size: number): number => (value / MODEL_IMAGE_SIZE) * size;

    return {
        ...polygon,
        points: polygon.points.map((point) => ({
            x: fromModelSpace(point.x, width),
            y: fromModelSpace(point.y, height),
        })),
    };
}

/**
 * Builds a Fabric.js polygon for an automatic license plate redaction, scaled by `scaleX` / `scaleY`.
 *
 * The polygon is tagged with `data.axType = 'automaticLicensePlateRedaction'` and filled with `color`. Without a
 * `color`, it also receives an `axId` (when missing) and an `axBlurFactor` from
 * `determineFabricObjectBlurFactor`; presumably such polygons are rendered as a blur — confirm.
 */
export function createLicensePlateRedactionFabricPolygon({
    redaction,
    scaleX,
    scaleY,
    color,
}: {
    redaction: LicensePlateRedactionPolygon;
    scaleX: number;
    scaleY: number;
    color?: string;
}): Polygon {
    // Scale points to current image size
    const scaledPoints = redaction.points.map(({ x, y }) => new Point(x * scaleX, y * scaleY));

    const polygon = new Polygon(scaledPoints, {
        fill: color,
        data: { axType: 'automaticLicensePlateRedaction' },
        // Spread last: any `fill` / `data` key in the shared defaults would override the values above.
        ...DEFAULT_FABRIC_CONTROL_OPTIONS,
    });

    if (color) return polygon;

    polygon.data ??= {};
    polygon.data.axId ??= generateId();
    polygon.data.axBlurFactor = determineFabricObjectBlurFactor({ object: polygon });
    return polygon;
}

/**
 * Returns whether `objects` already contains a polygon whose points equal `polygon`'s points exactly
 * (same count, same order, identical x/y values). Object types are compared case-insensitively.
 */
export function hasLicensePlateRedactionPolygonAlreadyBeenAdded({
    objects,
    polygon,
}: {
    objects: FabricObject[];
    polygon: Polygon;
}): boolean {
    const targetPoints = polygon.points;

    const matchesTarget = (candidatePoints: Point[]): boolean => {
        if (candidatePoints.length !== targetPoints.length) return false;
        for (let index = 0; index < candidatePoints.length; index++) {
            const candidate = candidatePoints[index];
            const target = targetPoints[index];
            if (candidate.x !== target.x || candidate.y !== target.y) return false;
        }
        return true;
    };

    for (const object of objects) {
        const isPolygon = object.type?.toLowerCase() === 'polygon';
        if (isPolygon && matchesTarget((object as Polygon).points as Point[])) return true;
    }

    return false;
}

/**
 * Converts a tensor into a plain, structured-clone/JSON friendly object (shape, flat values as an array, dtype).
 * The values are downloaded asynchronously via `tensor.data()`.
 */
export async function serializeTensor(tensor: Tensor): Promise<SerializedTensor> {
    const values = await tensor.data();
    const serialized: SerializedTensor = {
        shape: tensor.shape,
        data: Array.from(values),
        dtype: tensor.dtype,
    };
    return serialized;
}

/** Rebuilds a tfjs tensor from the plain object produced by `serializeTensor`. */
export function deserializeTensor(serializedTensor: SerializedTensor): Tensor {
    const { data, shape, dtype } = serializedTensor;
    return tensor(data, shape, dtype);
}
