/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite.utils;

import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Util;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.WindowBound;
import org.opensearch.sql.ast.expression.WindowFrame;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.PPLFuncImpTable;
import shaded.com.google.common.collect.ImmutableList;

public interface PlanUtils {
    public static final String ROW_NUMBER_COLUMN_NAME = "_row_number_";
    public static final String ROW_NUMBER_COLUMN_NAME_MAIN = "_row_number_main_";
    public static final String ROW_NUMBER_COLUMN_NAME_SUBSEARCH = "_row_number_subsearch_";

    public static SpanUnit intervalUnitToSpanUnit(IntervalUnit unit) {
        return switch (unit) {
            case IntervalUnit.MICROSECOND -> SpanUnit.MILLISECOND;
            case IntervalUnit.SECOND -> SpanUnit.SECOND;
            case IntervalUnit.MINUTE -> SpanUnit.MINUTE;
            case IntervalUnit.HOUR -> SpanUnit.HOUR;
            case IntervalUnit.DAY -> SpanUnit.DAY;
            case IntervalUnit.WEEK -> SpanUnit.WEEK;
            case IntervalUnit.MONTH -> SpanUnit.MONTH;
            case IntervalUnit.QUARTER -> SpanUnit.QUARTER;
            case IntervalUnit.YEAR -> SpanUnit.YEAR;
            case IntervalUnit.UNKNOWN -> SpanUnit.UNKNOWN;
            default -> throw new UnsupportedOperationException("Unsupported interval unit: " + String.valueOf((Object)unit));
        };
    }

    public static RexNode makeOver(CalcitePlanContext context, BuiltinFunctionName functionName, RexNode field, List<RexNode> argList, List<RexNode> partitions, List<RexNode> orderKeys, @Nullable WindowFrame windowFrame) {
        if (windowFrame == null) {
            windowFrame = WindowFrame.rowsUnbounded();
        }
        boolean rows = windowFrame.getType() == WindowFrame.FrameType.ROWS;
        RexWindowBound lowerBound = PlanUtils.convert(context, windowFrame.getLower());
        RexWindowBound upperBound = PlanUtils.convert(context, windowFrame.getUpper());
        switch (functionName) {
            case AVG: {
                return context.relBuilder.call((SqlOperator)SqlStdOperatorTable.DIVIDE, new RexNode[]{PlanUtils.sumOver(context, field, partitions, rows, lowerBound, upperBound), context.relBuilder.cast(PlanUtils.countOver(context, field, partitions, rows, lowerBound, upperBound), SqlTypeName.DOUBLE)});
            }
            case STDDEV_POP: {
                return PlanUtils.variance(context, field, partitions, rows, lowerBound, upperBound, true, true);
            }
            case STDDEV_SAMP: {
                return PlanUtils.variance(context, field, partitions, rows, lowerBound, upperBound, false, true);
            }
            case VARPOP: {
                return PlanUtils.variance(context, field, partitions, rows, lowerBound, upperBound, true, false);
            }
            case VARSAMP: {
                return PlanUtils.variance(context, field, partitions, rows, lowerBound, upperBound, false, false);
            }
            case ROW_NUMBER: {
                return PlanUtils.withOver(context.relBuilder.aggregateCall((SqlAggFunction)SqlStdOperatorTable.ROW_NUMBER, new RexNode[0]), partitions, orderKeys, true, lowerBound, upperBound);
            }
            case NTH_VALUE: {
                return PlanUtils.withOver(context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, new RexNode[]{field, argList.get(0)}), partitions, orderKeys, true, lowerBound, upperBound);
            }
        }
        return PlanUtils.withOver(PlanUtils.makeAggCall(context, functionName, false, field, argList), partitions, orderKeys, rows, lowerBound, upperBound);
    }

    private static RexNode sumOver(CalcitePlanContext ctx, RexNode operation, List<RexNode> partitions, boolean rows, RexWindowBound lowerBound, RexWindowBound upperBound) {
        return PlanUtils.withOver(ctx.relBuilder.sum(operation), partitions, List.of(), rows, lowerBound, upperBound);
    }

    private static RexNode countOver(CalcitePlanContext ctx, RexNode operation, List<RexNode> partitions, boolean rows, RexWindowBound lowerBound, RexWindowBound upperBound) {
        return PlanUtils.withOver(ctx.relBuilder.count((Iterable)ImmutableList.of((Object)operation)), partitions, List.of(), rows, lowerBound, upperBound);
    }

    private static RexNode withOver(RelBuilder.AggCall aggCall, List<RexNode> partitions, List<RexNode> orderKeys, boolean rows, RexWindowBound lowerBound, RexWindowBound upperBound) {
        return ((RelBuilder.OverCall)aggCall.over().partitionBy(partitions).orderBy(orderKeys).let(c -> rows ? c.rowsBetween(lowerBound, upperBound) : c.rangeBetween(lowerBound, upperBound))).toRex();
    }

    private static RexNode variance(CalcitePlanContext ctx, RexNode operator, List<RexNode> partitions, boolean rows, RexWindowBound lowerBound, RexWindowBound upperBound, boolean biased, boolean sqrt) {
        RexNode div;
        RexNode denominator;
        RexNode argSquared = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.MULTIPLY, new RexNode[]{operator, operator});
        RexNode sumArgSquared = PlanUtils.sumOver(ctx, argSquared, partitions, rows, lowerBound, upperBound);
        RexNode sum = PlanUtils.sumOver(ctx, operator, partitions, rows, lowerBound, upperBound);
        RexNode sumSquared = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.MULTIPLY, new RexNode[]{sum, sum});
        RexNode count = PlanUtils.countOver(ctx, operator, partitions, rows, lowerBound, upperBound);
        RexNode countCast = ctx.relBuilder.cast(count, SqlTypeName.DOUBLE);
        RexNode avgSumSquared = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.DIVIDE, new RexNode[]{sumSquared, countCast});
        RexNode diff = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.MINUS, new RexNode[]{sumArgSquared, avgSumSquared});
        if (biased) {
            denominator = countCast;
        } else {
            RexLiteral one = ctx.relBuilder.literal((Object)1);
            denominator = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.MINUS, new RexNode[]{countCast, one});
        }
        RexNode result = div = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.DIVIDE, new RexNode[]{diff, denominator});
        if (sqrt) {
            RexLiteral half = ctx.relBuilder.literal((Object)0.5);
            result = ctx.relBuilder.call((SqlOperator)SqlStdOperatorTable.POWER, new RexNode[]{div, half});
        }
        return result;
    }

    public static RexWindowBound convert(CalcitePlanContext context, WindowBound windowBound) {
        if (windowBound instanceof WindowBound.UnboundedWindowBound) {
            WindowBound.UnboundedWindowBound unbounded = (WindowBound.UnboundedWindowBound)windowBound;
            if (unbounded.isPreceding()) {
                return RexWindowBounds.UNBOUNDED_PRECEDING;
            }
            return RexWindowBounds.UNBOUNDED_FOLLOWING;
        }
        if (windowBound instanceof WindowBound.CurrentRowWindowBound) {
            WindowBound.CurrentRowWindowBound current = (WindowBound.CurrentRowWindowBound)windowBound;
            return RexWindowBounds.CURRENT_ROW;
        }
        if (windowBound instanceof WindowBound.OffSetWindowBound) {
            WindowBound.OffSetWindowBound offset = (WindowBound.OffSetWindowBound)windowBound;
            if (offset.isPreceding()) {
                return RexWindowBounds.preceding((RexNode)context.relBuilder.literal((Object)offset.getOffset()));
            }
            return RexWindowBounds.following((RexNode)context.relBuilder.literal((Object)offset.getOffset()));
        }
        throw new UnsupportedOperationException("Unexpected window bound: " + String.valueOf(windowBound));
    }

    public static RelBuilder.AggCall makeAggCall(CalcitePlanContext context, BuiltinFunctionName functionName, boolean distinct, RexNode field, List<RexNode> argList) {
        return PPLFuncImpTable.INSTANCE.resolveAgg(functionName, distinct, field, argList, context);
    }

    public static List<RexInputRef> getInputRefs(RexNode node) {
        final ArrayList<RexInputRef> inputRefs = new ArrayList<RexInputRef>();
        node.accept((RexVisitor)new RexVisitorImpl<Void>(true){

            public Void visitInputRef(RexInputRef inputRef) {
                if (!inputRefs.contains(inputRef)) {
                    inputRefs.add(inputRef);
                }
                return null;
            }
        });
        return inputRefs;
    }

    public static List<RexInputRef> getInputRefs(List<RexNode> nodes) {
        return nodes.stream().flatMap(node -> PlanUtils.getInputRefs(node).stream()).toList();
    }

    public static List<RexInputRef> getInputRefsFromAggCall(List<RelBuilder.AggCall> aggCalls) {
        return aggCalls.stream().map(RelBuilder.AggCall::over).map(RelBuilder.OverCall::toRex).flatMap(rex -> PlanUtils.getInputRefs(rex).stream()).toList();
    }

    public static UnresolvedPlan getRelation(UnresolvedPlan node) {
        AbstractNodeVisitor<Relation, Object> relationVisitor = new AbstractNodeVisitor<Relation, Object>(){

            @Override
            public Relation visitRelation(Relation node, Object context) {
                return node;
            }
        };
        return node.getChild().getFirst().accept(relationVisitor, null);
    }

    public static RelOptTable findTable(RelNode root) {
        try {
            RelHomogeneousShuttle visitor = new RelHomogeneousShuttle(){

                public RelNode visit(TableScan scan) {
                    RelOptTable scanTable = scan.getTable();
                    throw new Util.FoundOne((Object)scanTable);
                }
            };
            root.accept((RelShuttle)visitor);
            return null;
        }
        catch (Util.FoundOne e) {
            Util.swallow((Throwable)e, null);
            return (RelOptTable)e.getNode();
        }
    }

    public static void transformPlanToAttachChild(UnresolvedPlan node, final UnresolvedPlan child) {
        AbstractNodeVisitor<Void, Object> leafVisitor = new AbstractNodeVisitor<Void, Object>(){

            @Override
            public Void visitChildren(Node node, Object context) {
                if (node.getChild() == null || node.getChild().isEmpty()) {
                    ((UnresolvedPlan)node).attach(child);
                } else {
                    node.getChild().forEach(child -> child.accept(this, context));
                }
                return null;
            }
        };
        node.accept(leafVisitor, null);
    }

    public static RexNode derefMapCall(RexNode rexNode) {
        RexCall call;
        if (rexNode instanceof RexCall && (call = (RexCall)rexNode).getOperator() == SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR) {
            return (RexNode)call.getOperands().get(1);
        }
        return rexNode;
    }
}

