import { FC, useCallback, useEffect, useRef } from "react";
import * as d3 from "d3";
import { useTranslation } from "react-i18next";
import { SummaryWithTotals } from "../contexts/DataContext";
import { Box } from "@mui/material";

/** Props accepted by the `WaterfallChart` component. */
interface WaterfallChartProps {
    /** Rows, column keys, per-column totals and grand total to plot. */
    data: SummaryWithTotals;
    /** Prefix joined with a row title or column key to form its translation key. */
    translationPrefix: string;
    /** Palette mapped onto the columns, in column order. */
    chartColors: string[];
    /** Alternative palettes, keyed by the bar title as displayed, used instead of `chartColors` for that bar. */
    overrideColors?: Record<string, string[]>;
    /** Whether the column legend is drawn. */
    showLegend: boolean;
    /** Whether the x-axis labels are tilted instead of drawn horizontally. */
    rotateXaxis: boolean;
    /** Fill colour of the value labels drawn inside the bar segments. */
    chartTextColor: string;
    /** Text of the rotated label drawn next to the y-axis. */
    yAxisTitle: string;
}

/** Fixed pixel width of the rendered SVG. */
const CHART_WIDTH = 450;
/** Fixed pixel height of the rendered SVG. */
const CHART_HEIGHT = 250;
/** Space reserved around the bar area for the axes, the axis title and the legend. */
const CHART_MARGIN = { top: 20, right: 30, bottom: 40, left: 50 };
/** Title of the synthetic last bar. Note that this label is not passed through `t`. */
const TOTAL_LABEL = "Total";

/**
 * Stacked waterfall chart rendered imperatively with d3 into a fixed-size SVG.
 *
 * Every row of `data.data` becomes one bar whose segments (one per column) are
 * stacked on top of the running total of all previous segments, so consecutive
 * bars climb like a staircase. A final "Total" bar stacks `data.columnTotals`
 * starting from zero again.
 *
 * The whole chart is torn down and redrawn whenever one of the props changes.
 */
const WaterfallChart: FC<WaterfallChartProps> = ({
    data,
    translationPrefix,
    chartColors,
    overrideColors,
    showLegend,
    rotateXaxis,
    chartTextColor,
    yAxisTitle,
}) => {
    const svgRef = useRef<SVGSVGElement | null>(null);
    const { t } = useTranslation();

    /**
     * Converts the summary into one entry per bar. Each entry carries the bar
     * title and its segments; a segment records the value it starts at, its
     * own height (missing values count as 0) and the column it belongs to.
     */
    const getStackedData = useCallback(
        (summary: SummaryWithTotals) => {
            // Running total shared by all row bars; this is what makes the chart a waterfall.
            let runningTotal = 0;
            const toSegments = (heights: ReadonlyArray<number | null | undefined>) =>
                heights.map((rawHeight, columnIndex) => {
                    const height = rawHeight ?? 0;
                    const segment = {
                        start: runningTotal,
                        height,
                        key: summary.columns[columnIndex],
                    };
                    runningTotal += height;
                    return segment;
                });

            const bars = summary.data.map((row) => {
                const title = t(`${translationPrefix}.${row.title}`);
                return { values: toSegments(row.values), title };
            });

            // The closing bar starts from the baseline again.
            runningTotal = 0;
            bars.push({ values: toSegments(summary.columnTotals), title: TOTAL_LABEL });
            return bars;
        },
        [t, translationPrefix],
    );

    useEffect(() => {
        const svgElement = svgRef.current;
        if (!svgElement) return;

        // Start from an empty SVG: the chart is fully redrawn on every run.
        const svg = d3.select(svgElement);
        svg.selectAll("*").remove();

        const { columns } = data;
        const margin = CHART_MARGIN;
        const barsArea = {
            width: CHART_WIDTH - margin.left - margin.right,
            height: CHART_HEIGHT - margin.top - margin.bottom,
        };

        const stackedData = getStackedData(data);

        // One band per bar, in drawing order (all rows, then the total bar).
        const xScale = d3
            .scaleBand()
            .domain(stackedData.map((bar) => bar.title))
            .range([0, barsArea.width])
            .padding(0.1);

        // The range only spans the bar area (no margin offset) so that the scale can
        // also be used to convert a value into a pixel height.
        // A zero-width domain makes d3 map every value to the middle of the range,
        // which would draw every segment at half the chart height, so fall back to
        // a unit domain when the total is 0 or missing.
        // NOTE(review): a negative segment under a positive total produces a negative
        // rect height - confirm that inputs are non-negative.
        const total = data.total ?? 0;
        const yScale = d3
            .scaleLinear()
            .domain([0, total === 0 ? 1 : total])
            .nice()
            .range([0, barsArea.height]);

        // Colours per column; a bar whose title has an override entry uses its own palette.
        const colorScale = d3.scaleOrdinal(columns, chartColors);
        const overrideScales = Object.fromEntries(
            Object.entries(overrideColors || {}).map(([title, colors]) => {
                return [title, d3.scaleOrdinal(columns, colors)];
            }),
        );

        /**
         * Appends a layer with one group per bar and, inside it, one group per
         * segment positioned at the top edge of that segment. When `centered` is
         * set the bar groups are placed at the horizontal middle of their band
         * (for text), otherwise at its left edge (for rects).
         */
        const appendSegmentGroups = (className: string, centered: boolean) =>
            svg
                .append("g")
                .attr("class", className)
                .attr("transform", `translate(${margin.left},${margin.top})`)
                .selectAll(".bar-group")
                .data(stackedData)
                .enter()
                .append("g")
                .attr(
                    "transform",
                    (bar) => `translate(${(xScale(bar.title) ?? 0) + (centered ? xScale.bandwidth() / 2 : 0)},0)`,
                )
                .selectAll("rect")
                .data((bar) => bar.values.map((segment) => ({ ...segment, title: bar.title })))
                .enter()
                .append("g")
                .attr(
                    "transform",
                    (segment) => `translate(0, ${barsArea.height - yScale(segment.start + segment.height)})`,
                );

        // Horizontal grid lines: a left axis whose ticks are stretched across the bar area.
        svg.append("g")
            .attr("class", "y-grid")
            .attr("transform", `translate(${margin.left},${margin.top})`)
            .call(
                d3
                    .axisLeft(yScale)
                    .tickSizeInner(-barsArea.width)
                    .tickFormat(() => ""), // Hide the tick values
            )
            .selectAll("line")
            .style("opacity", 0.1);

        // Y-axis title, rotated to read bottom-to-top next to the axis.
        svg.append("text")
            .attr("x", margin.top - 105)
            .attr("y", margin.left - 10)
            .attr("dy", "0.31em")
            .style("text-anchor", "end")
            .text(yAxisTitle)
            .style("font-size", "10px")
            .style("fill", "black")
            .attr("transform", "rotate(-90)");

        // Bar segments
        appendSegmentGroups("bars", false)
            .append("rect")
            .attr("fill", (segment) =>
                segment.title in overrideScales ? overrideScales[segment.title](segment.key) : colorScale(segment.key),
            )
            .attr("height", (segment) => yScale(segment.height))
            .attr("width", xScale.bandwidth());

        // Value labels, centred in their segment; hidden when the segment is too short to hold them.
        appendSegmentGroups("bar-values", true)
            .append("text")
            .text((segment) => (yScale(segment.height) < 10 ? "" : Math.round(segment.height)))
            .attr("y", (segment) => yScale(segment.height) / 2)
            .attr("dy", "0.31em")
            .attr("text-anchor", "middle")
            .style("font-size", "12px")
            .style("fill", chartTextColor);

        // X-axis with the bar titles
        svg.append("g")
            .attr("class", "x-axis")
            .attr("transform", `translate(${margin.left},${CHART_HEIGHT - margin.bottom})`)
            .call(d3.axisBottom(xScale).tickSizeInner(0))
            .selectAll(".tick text")
            .style("font-size", "9px")
            .style("text-anchor", "middle")
            // Either tilt the labels by 15 degrees or just nudge them below the axis.
            .attr("transform", rotateXaxis ? "translate(14,10) rotate(15)" : "translate(0,5)");

        // Legend: one coloured dot and translated name per column
        if (showLegend) {
            const legendItems = svg
                .append("g")
                .attr("class", "legend")
                .attr("transform", `translate(${CHART_WIDTH - 300},${margin.top - 13})`)
                .selectAll(".legend-item")
                .data(columns)
                .enter()
                .append("g")
                .attr("class", "legend-item")
                .attr("transform", (_column, index) => `translate(${index * 80},0)`);

            const radius = 7;
            legendItems
                .append("circle")
                .attr("r", radius)
                .attr("cy", radius / 2)
                .attr("fill", (column) => colorScale(column));

            legendItems
                .append("text")
                .attr("x", 15)
                .attr("dy", "0.31em")
                .text((column) => t(`${translationPrefix}.${column}`))
                .style("font-size", "14px")
                .attr("alignment-baseline", "middle");
        }

        // Hover values: invisible copies of every value (including those hidden above)
        // that show on mouse-over and fade out again after the pointer leaves.
        appendSegmentGroups("tooltip", true)
            .append("text")
            .text((segment) => Math.round(segment.height))
            .attr("y", (segment) => yScale(segment.height) / 2)
            .attr("dy", "0.31em")
            .style("opacity", "0")
            .attr("text-anchor", "middle")
            .style("font-size", "12px")
            .style("fill", "black")
            .on("mouseover", function () {
                d3.select(this).style("opacity", "1");
            })
            .on("mouseout", function () {
                d3.select(this).transition().duration(1200).style("opacity", "0");
            });
    }, [chartColors, chartTextColor, data, rotateXaxis, showLegend, t, translationPrefix, overrideColors, getStackedData, yAxisTitle]);

    return (
        <Box position="relative">
            <svg ref={svgRef} width={CHART_WIDTH} height={CHART_HEIGHT}>
                {/* Content is drawn imperatively by the effect above */}
            </svg>
        </Box>
    );
};

export default WaterfallChart;
