/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.rec.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Generated;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.metadata.model.tool.CalciteParser;
import org.apache.kylin.metadata.project.NProjectManager;
import org.apache.kylin.query.IQueryTransformer;
import org.apache.kylin.query.util.AbstractSqlVisitor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Query transformer that shrinks OR-heavy conditions by rewriting the SQL text.
 *
 * <p>When the number of OR calls found in the visited WHERE conditions reaches the configured
 * threshold, it repeatedly (at most {@code getOptimizeTransformerMaxIterations()} times) does
 * two things:
 * <ul>
 *   <li>merges sibling OR disjuncts of the form {@code x = a} / {@code x IN (...)} that share
 *       the same left operand into a single {@code IN} predicate;</li>
 *   <li>collapses {@code IN} value lists. Once the list has one entry, further literal entries
 *       are dropped. A single remaining entry is rendered as {@code x = v}.</li>
 * </ul>
 *
 * <p>Edits are spliced into the original text by source position rather than by unparsing the
 * AST, so untouched parts of the query keep their formatting. Because literal values are
 * dropped, the output is generally NOT semantically equivalent to the input.
 * NOTE(review): presumably acceptable because this lives in the recommendation module, where
 * only the referenced columns matter — confirm before reusing elsewhere.
 */
public class OptimizeTransformer implements IQueryTransformer {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(OptimizeTransformer.class);

    /**
     * Rewrites {@code originSql} as described in the class comment.
     *
     * @param originSql     the query text to transform
     * @param project       project whose config supplies the iteration limit and OR threshold
     * @param defaultSchema unused by this transformer
     * @return the rewritten SQL, or {@code originSql} unchanged if anything goes wrong
     */
    @Override
    public String transform(String originSql, String project, String defaultSchema) {
        try {
            KylinConfig config = NProjectManager.getProjectConfig(project);
            String sql = originSql;
            long start = System.currentTimeMillis();
            for (int i = 0; i < config.getOptimizeTransformerMaxIterations(); ++i) {
                SqlNode node = CalciteParser.parse(sql);
                // The visitors must be given the text `node` was parsed from. After the first
                // iteration, that is the rewritten `sql`, not `originSql`.
                ConditionCounter counter = new ConditionCounter(sql, config);
                MergeOrExprAndCollectInExprMatcher matcher = new MergeOrExprAndCollectInExprMatcher(sql);
                if (counter.isBelowThreshold(node) || !matcher.isMergeableOrCollapsible(node)) {
                    break;
                }
                sql = updateCondition(sql, matcher.getToUpdateNodes());
            }
            log.debug("OptimizeTransformer cost: {} ms, the transformed sql is:\n{}",
                    System.currentTimeMillis() - start, sql);
            return sql;
        } catch (Exception e) {
            // Best-effort optimization: any failure falls back to the untouched query.
            log.warn("Something unexpected in OptimizeTransformer, return original query", e);
            return originSql;
        }
    }

    /**
     * Applies the collected DELETE/MODIFY actions to {@code originSql} textually.
     *
     * @param originSql     the SQL text the marked nodes were parsed from
     * @param toUpdateNodes nodes to rewrite, with the action for each
     * @return the rewritten SQL, trimmed
     * @throws IllegalStateException if an action is not supported or a node cannot be removed
     */
    private String updateCondition(String originSql, Map<SqlBasicCall, Action> toUpdateNodes) {
        Map<SqlBasicCall, Pair<Integer, Integer>> nodePositions = new HashMap<>();
        for (SqlBasicCall node : toUpdateNodes.keySet()) {
            nodePositions.put(node, CalciteParser.getReplacePos(node, originSql));
        }
        List<SqlBasicCall> allNodes = new ArrayList<>(toUpdateNodes.keySet());
        // Edit from the end of the text towards the start, so offsets of not-yet-processed
        // (earlier) nodes remain valid after each splice.
        allNodes.sort((o1, o2) -> Integer.compare(
                nodePositions.get(o2).getFirst(), nodePositions.get(o1).getFirst()));
        String sql = originSql + " ";
        for (SqlBasicCall call : allNodes) {
            Pair<Integer, Integer> pos = nodePositions.get(call);
            Action action = toUpdateNodes.get(call);
            if (action == Action.DELETE) {
                sql = dropOrNode(sql, call, pos);
            } else if (action == Action.MODIFY) {
                sql = collapseNode(originSql, sql, call, pos);
            } else {
                throw new IllegalStateException("Not support action: " + action);
            }
        }
        return sql.trim();
    }

    /**
     * Removes the text span {@code pos} from {@code sql}, together with one adjacent {@code OR}.
     * Any parentheses that wrapped only that span are removed as well.
     *
     * @throws IllegalStateException if neither neighbour of the span is an {@code OR}
     */
    private String dropOrNode(String sql, SqlBasicCall call, Pair<Integer, Integer> pos) {
        String left = sql.substring(0, pos.getFirst()).trim();
        String right = sql.substring(pos.getSecond()).trim();
        while (left.endsWith("(") && right.startsWith(")")) {
            left = left.substring(0, left.length() - 1).trim();
            right = right.substring(1).trim();
        }
        if (left.length() > 2 && left.regionMatches(true, left.length() - 2, "or", 0, 2)) {
            left = left.substring(0, left.length() - 2).trim();
        } else if (right.length() > 2 && right.regionMatches(true, 0, "or", 0, 2)) {
            right = right.substring(2).trim();
        } else {
            throw new IllegalStateException("SqlNode could not be delete form sql: " + call);
        }
        return left + " " + right;
    }

    /**
     * Replaces the text span {@code pos} of {@code sql} with a re-rendered form of the
     * {@code IN} call {@code in}.
     *
     * <p>Once the value set is non-empty, literal entries are skipped; non-literal entries are
     * kept, de-duplicated by their source text. A single remaining value is rendered as
     * {@code lhs = value}; otherwise the result is {@code lhs IN (v1, v2, ...)}. Operand text is
     * taken from {@code originSql} by source position.
     */
    private String collapseNode(String originSql, String sql, SqlBasicCall in, Pair<Integer, Integer> pos) {
        assert in.getKind() == SqlKind.IN;
        assert in.operand(1) instanceof SqlNodeList;
        Set<String> newCommaList = new LinkedHashSet<>();
        for (SqlNode para : ((SqlNodeList) in.operand(1)).getList()) {
            if (!newCommaList.isEmpty() && para instanceof SqlLiteral) {
                continue;
            }
            newCommaList.add(nodeToString(para, originSql));
        }
        String left = sql.substring(0, pos.getFirst()).trim();
        String right = sql.substring(pos.getSecond()).trim();
        String lhs = nodeToString(in.operand(0), originSql);
        if (newCommaList.size() == 1) {
            return left + " " + lhs + " = " + newCommaList.iterator().next() + " " + right;
        }
        return left + " " + lhs + " IN (" + String.join(", ", newCommaList) + ") " + right;
    }

    /** Returns the slice of {@code originSql} that {@code node} was parsed from. */
    private String nodeToString(SqlNode node, String originSql) {
        Pair<Integer, Integer> pos = CalciteParser.getReplacePos(node, originSql);
        return originSql.substring(pos.getFirst(), pos.getSecond());
    }

    /**
     * Visitor that walks WHERE conditions, merging mergeable OR disjuncts in the AST and
     * recording which calls must be modified or deleted in the SQL text.
     */
    static class MergeOrExprAndCollectInExprMatcher extends AbstractSqlVisitor {
        private final Map<SqlBasicCall, Action> toUpdateNodes = new HashMap<>();

        protected MergeOrExprAndCollectInExprMatcher(String originSql) {
            super(originSql);
        }

        /** Visits {@code sql}; returns true if any node was marked for rewriting. */
        private boolean isMergeableOrCollapsible(SqlNode sql) {
            sql.accept(this);
            return !toUpdateNodes.isEmpty();
        }

        private Map<SqlBasicCall, Action> getToUpdateNodes() {
            return toUpdateNodes;
        }

        @Override
        protected void visitInSqlWhere(SqlNode node) {
            // LinkedList (not ArrayDeque): Calcite operand lists may contain null entries.
            LinkedList<SqlNode> conditions = new LinkedList<>();
            conditions.add(node);
            while (!conditions.isEmpty()) {
                SqlNode cond = conditions.poll();
                if (!(cond instanceof SqlBasicCall)) {
                    continue;
                }
                SqlBasicCall call = (SqlBasicCall) cond;
                SqlKind kind = call.getOperator().getKind();
                if (kind == SqlKind.IN) {
                    collectInCommaList(call);
                } else if (kind == SqlKind.AND) {
                    conditions.addAll(call.getOperandList());
                } else if (kind == SqlKind.OR) {
                    // Merge first, so the operands queued next are the post-merge ones.
                    mergeOrCondition(call);
                    conditions.addAll(call.getOperandList());
                } else {
                    call.getOperator().acceptCall(this, call);
                }
            }
        }

        /**
         * Marks an {@code IN (list)} call for collapsing, or descends into an
         * {@code IN (subquery)}.
         */
        private void collectInCommaList(SqlBasicCall in) {
            assert in.getKind() == SqlKind.IN;
            SqlNode values = in.operand(1);
            if (values instanceof SqlNodeList) {
                toUpdateNodes.put(in, Action.MODIFY);
            } else if (values instanceof SqlSelect) {
                values.accept(this);
            } else {
                throw new IllegalStateException("Unsupported sql syntax");
            }
        }

        /**
         * Repeatedly merges two mergeable leaves among {downOr.left, downOr.right, topRight} of
         * the pattern {@code topOr(downOr(downLeft, downRight), topRight)}. Each merge removes
         * one nesting level from {@code topOr}. Stops when no nested OR or no mergeable pair
         * remains. The checks run in a fixed order because {@link #isMergeable} normalizes its
         * arguments in place.
         */
        private void mergeOrCondition(SqlBasicCall topOr) {
            while (true) {
                SqlNode topLeft = topOr.operand(0);
                SqlNode topRight = topOr.operand(1);
                // Normalize so that a nested OR, if there is one, sits on the left.
                if (topLeft.getKind() != SqlKind.OR) {
                    SqlNode tmp = topLeft;
                    topLeft = topRight;
                    topRight = tmp;
                }
                if (topLeft.getKind() != SqlKind.OR) {
                    return;
                }
                SqlBasicCall downOr = (SqlBasicCall) topLeft;
                SqlNode downLeft = downOr.operand(0);
                SqlNode downRight = downOr.operand(1);
                if (isMergeable(downLeft, downRight)) {
                    absorbInto(topOr, (SqlBasicCall) downLeft, (SqlBasicCall) downRight, topRight);
                } else if (isMergeable(downLeft, topRight)) {
                    absorbInto(topOr, (SqlBasicCall) topRight, (SqlBasicCall) downLeft, downRight);
                } else if (isMergeable(topRight, downRight)) {
                    absorbInto(topOr, (SqlBasicCall) topRight, (SqlBasicCall) downRight, downLeft);
                } else {
                    return;
                }
            }
        }

        /**
         * Merges {@code dropped} into {@code kept} and re-links {@code topOr} as
         * {@code OR(kept, remaining)}. Records {@code kept} as MODIFY and {@code dropped}
         * as DELETE.
         */
        private void absorbInto(SqlBasicCall topOr, SqlBasicCall kept, SqlBasicCall dropped, SqlNode remaining) {
            mergeToLeft(kept, dropped);
            topOr.setOperand(0, kept);
            topOr.setOperand(1, remaining);
            toUpdateNodes.put(kept, Action.MODIFY);
            toUpdateNodes.put(dropped, Action.DELETE);
        }

        /** True for {@code x = y} or for {@code x IN (explicit list)}. */
        private boolean inCommaListOrIsEqual(SqlNode node) {
            return node.getKind() == SqlKind.IN && ((SqlBasicCall) node).operand(1) instanceof SqlNodeList
                    || node.getKind() == SqlKind.EQUALS;
        }

        /** Rewrites {@code v = col} to {@code col = v} in place when only the right side is an identifier. */
        private void mvIdentifierToLeft(SqlBasicCall call) {
            if (call.getKind() == SqlKind.EQUALS
                    && !(call.operand(0) instanceof SqlIdentifier)
                    && call.operand(1) instanceof SqlIdentifier) {
                SqlNode tmp = call.operand(0);
                call.setOperand(0, call.operand(1));
                call.setOperand(1, tmp);
            }
        }

        /**
         * True when both conditions are {@code =}/{@code IN (list)} with textually equal left
         * operands. Side effect: may swap operands of {@code =} calls (see
         * {@link #mvIdentifierToLeft}).
         */
        private boolean isMergeable(SqlNode cond1, SqlNode cond2) {
            if (!inCommaListOrIsEqual(cond1) || !inCommaListOrIsEqual(cond2)) {
                return false;
            }
            assert cond1 instanceof SqlBasicCall;
            assert cond2 instanceof SqlBasicCall;
            SqlBasicCall call1 = (SqlBasicCall) cond1;
            SqlBasicCall call2 = (SqlBasicCall) cond2;
            mvIdentifierToLeft(call1);
            mvIdentifierToLeft(call2);
            return call1.operand(0).toString().equals(call2.operand(0).toString());
        }

        /** Folds the right-hand values of {@code rightCall} into {@code leftCall}, turning it into an IN call. */
        private void mergeToLeft(SqlBasicCall leftCall, SqlBasicCall rightCall) {
            if (leftCall.getKind() == SqlKind.EQUALS) {
                leftCall.setOperator((SqlOperator) SqlStdOperatorTable.IN);
                if (rightCall.getKind() == SqlKind.EQUALS) {
                    leftCall.setOperand(1, new SqlNodeList(
                            Arrays.asList(leftCall.operand(1), rightCall.operand(1)),
                            leftCall.operand(1).getParserPosition()));
                } else if (rightCall.getKind() == SqlKind.IN) {
                    // Reuses rightCall's list; rightCall is deleted from the text afterwards.
                    SqlNodeList nodeList = rightCall.operand(1);
                    nodeList.add(leftCall.operand(1));
                    leftCall.setOperand(1, nodeList);
                }
            } else if (leftCall.getKind() == SqlKind.IN) {
                SqlNodeList leftList = leftCall.operand(1);
                if (rightCall.getKind() == SqlKind.EQUALS) {
                    leftList.add(rightCall.operand(1));
                } else if (rightCall.getKind() == SqlKind.IN) {
                    leftList.getList().addAll(((SqlNodeList) rightCall.operand(1)).getList());
                }
            }
        }
    }

    /**
     * Visitor that counts OR calls in WHERE conditions and compares the count with
     * {@code getOptimizeTransformerConditionCountThreshold()}.
     */
    static class ConditionCounter extends AbstractSqlVisitor {
        int orConditionCnt = 0;

        /** Config supplying the threshold; {@code null} means "use the env config". */
        private final KylinConfig config;

        protected ConditionCounter(String originSql) {
            this(originSql, null);
        }

        protected ConditionCounter(String originSql, KylinConfig config) {
            super(originSql);
            this.config = config;
        }

        /** Visits {@code sql}; returns true if fewer OR calls were seen than the threshold. */
        private boolean isBelowThreshold(SqlNode sql) {
            sql.accept(this);
            KylinConfig effective = config != null ? config : KylinConfig.getInstanceFromEnv();
            return orConditionCnt < effective.getOptimizeTransformerConditionCountThreshold();
        }

        @Override
        protected void visitInSqlWhere(SqlNode node) {
            // LinkedList (not ArrayDeque): Calcite operand lists may contain null entries.
            LinkedList<SqlNode> conditions = new LinkedList<>();
            conditions.add(node);
            while (!conditions.isEmpty()) {
                SqlNode cond = conditions.poll();
                if (!(cond instanceof SqlBasicCall)) {
                    continue;
                }
                SqlBasicCall call = (SqlBasicCall) cond;
                SqlKind kind = call.getOperator().getKind();
                if (kind == SqlKind.AND) {
                    conditions.addAll(call.getOperandList());
                } else if (kind == SqlKind.OR) {
                    conditions.addAll(call.getOperandList());
                    ++orConditionCnt;
                } else {
                    call.getOperator().acceptCall(this, call);
                }
            }
        }
    }

    /** Textual rewrite to apply to a marked call. */
    private enum Action {
        MODIFY,
        DELETE
    }
}

