/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.Generated;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;

/**
 * Static helpers shared by the agent runners in this package.
 *
 * <p>The helpers fall into three groups:
 * <ul>
 *   <li>prompt assembly: filling {@code ${parameters.<name>}} placeholders in a prompt template
 *       with examples, tool descriptions, index lists, chat history and context;</li>
 *   <li>LLM output parsing: extracting the {@code thought}, {@code action}, {@code action_input}
 *       and {@code final_answer} fields from a model response that may or may not be valid JSON;</li>
 *   <li>tool wiring: selecting tool specs, instantiating {@link Tool}s from their factories and
 *       building the parameter map passed to a tool invocation.</li>
 * </ul>
 */
public class AgentUtils {
    @Generated
    private static final Logger log = LogManager.getLogger(AgentUtils.class);
    /** Parameter key holding a JSON array of tool names used to filter the agent's tools. */
    public static final String SELECTED_TOOLS = "selected_tools";
    /** Parameter key whose value replaces {@code ${parameters.prompt.prefix}}. */
    public static final String PROMPT_PREFIX = "prompt.prefix";
    /** Parameter key whose value replaces {@code ${parameters.prompt.suffix}}. */
    public static final String PROMPT_SUFFIX = "prompt.suffix";
    public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction";
    public static final String TOOL_RESPONSE = "prompt.tool_response";
    public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
    public static final String DISABLE_TRACE = "disable_trace";
    public static final String VERBOSE = "verbose";
    /** Tool parameter key that always receives the raw action input produced by the LLM. */
    public static final String LLM_GEN_INPUT = "llm_generated_input";
    /** Default regexes tried by {@link #extractModelResponseJson(String, List)} after any caller-supplied patterns. */
    public static List<String> MODEL_RESPONSE_PATTERNS = List.of("\\{\\s*(\"(thought|action|action_input|final_answer)\"\\s*:\\s*\".*?\"\\s*,?\\s*)+\\}");

    /** Placeholder delimiters shared by every prompt substitution in this class. */
    private static final String PARAMETERS_PREFIX = "${parameters.";
    private static final String PARAMETERS_SUFFIX = "}";

    // Field extractors for model responses that are not valid JSON. Compiled once; all use DOTALL
    // so values may span multiple lines.
    // Greedy: captures up to the LAST double quote in the text.
    private static final Pattern FINAL_ANSWER_PATTERN = Pattern.compile("\"final_answer\"\\s*:\\s*\"(.*)\"", Pattern.DOTALL);
    // Lazy: captures up to the first `", "<next key>"` where the next key is one of the listed ones.
    // (Previously written as a character class `[...]`, which matched any single listed character.)
    private static final Pattern THOUGHT_PATTERN = Pattern.compile("\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*\"(?:final_answer|action|action_input)\"", Pattern.DOTALL);
    // Captures everything up to the "action_input" key (or end of text), so trailing `", ` etc. may
    // be included; callers normalize the value with getMatchedTool.
    private static final Pattern ACTION_PATTERN = Pattern.compile("\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)", Pattern.DOTALL);
    // NOTE(review): `[^\"]|\"` matches any character, so this is greedy up to the LAST double quote in
    // the text. That tolerates unescaped quotes inside the value (e.g. inline JSON) but also swallows
    // any keys that follow action_input. Kept as-is to preserve behavior.
    private static final Pattern ACTION_INPUT_PATTERN = Pattern.compile("\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\"", Pattern.DOTALL);

    /** Replaces {@code ${parameters.<key>}} placeholders in {@code prompt} with values from {@code values}. */
    private static String substitute(Map<String, String> values, String prompt) {
        return new StringSubstitutor(values, PARAMETERS_PREFIX, PARAMETERS_SUFFIX).replace(prompt);
    }

    /** Returns {@code parameters.get(key)}, or {@code defaultValue} when the key is absent or mapped to null. */
    private static String valueOrDefault(Map<String, String> parameters, String key, String defaultValue) {
        return Optional.ofNullable(parameters.get(key)).orElse(defaultValue);
    }

    /** Parses a JSON array; a JSON {@code null}/empty input yields an empty list instead of null. */
    private static List<?> parseJsonList(String json) {
        List<?> parsed = StringUtils.gson.fromJson(json, List.class);
        return parsed == null ? List.of() : parsed;
    }

    /** Appends {@code items}, each wrapped in item prefix/suffix, with the whole list wrapped in list prefix/suffix. */
    private static void appendWrappedItems(StringBuilder builder, List<?> items, String listPrefix, String listSuffix, String itemPrefix, String itemSuffix) {
        builder.append(listPrefix);
        for (Object item : items) {
            builder.append(itemPrefix).append(item).append(itemSuffix);
        }
        builder.append(listSuffix);
    }

    /**
     * Fills {@code ${parameters.examples}} in the prompt.
     *
     * <p>When {@code parameters} has an {@code examples} entry (a JSON array), each example is wrapped
     * in {@code examples.example.prefix/suffix} and the list in {@code examples.prefix/suffix}, under an
     * {@code EXAMPLES} header; otherwise the placeholder is replaced with an empty string.
     *
     * @param parameters agent parameters
     * @param prompt     prompt template
     * @return the prompt with the examples placeholder substituted
     */
    public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
        String renderedExamples = "";
        if (parameters.containsKey("examples")) {
            StringBuilder exampleBuilder = new StringBuilder("EXAMPLES\n--------\n");
            appendWrappedItems(
                exampleBuilder,
                parseJsonList(parameters.get("examples")),
                valueOrDefault(parameters, "examples.prefix", "You should follow and learn from examples defined in <examples>: \n<examples>\n"),
                valueOrDefault(parameters, "examples.suffix", "</examples>\n"),
                valueOrDefault(parameters, "examples.example.prefix", "<example>\n"),
                valueOrDefault(parameters, "examples.example.suffix", "\n</example>\n")
            );
            renderedExamples = exampleBuilder.toString();
        }
        Map<String, String> examplesMap = new HashMap<>();
        examplesMap.put("examples", renderedExamples);
        return substitute(examplesMap, prompt);
    }

    /**
     * Fills {@code ${parameters.prompt.prefix}} and {@code ${parameters.prompt.suffix}} from the
     * corresponding parameters, defaulting each to an empty string.
     */
    public static String addPrefixSuffixToPrompt(Map<String, String> parameters, String prompt) {
        Map<String, String> prefixMap = new HashMap<>();
        prefixMap.put(PROMPT_PREFIX, parameters.getOrDefault(PROMPT_PREFIX, ""));
        prefixMap.put(PROMPT_SUFFIX, parameters.getOrDefault(PROMPT_SUFFIX, ""));
        return substitute(prefixMap, prompt);
    }

    /**
     * Fills {@code ${parameters.tool_descriptions}} and {@code ${parameters.tool_names}}.
     *
     * <p>Descriptions are rendered from {@code tools} for every name in {@code inputTools}, and names
     * are joined with {@code ", "}. Explicit {@code tool_descriptions} / {@code tool_names} entries in
     * {@code parameters} take precedence over the generated values.
     *
     * @throws IllegalArgumentException if a name in {@code inputTools} is not a key of {@code tools}
     */
    public static String addToolsToPrompt(Map<String, Tool> tools, Map<String, String> parameters, List<String> inputTools, String prompt) {
        String toolsPrefix = valueOrDefault(parameters, "agent.tools.prefix", "You have access to the following tools defined in <tools>: \n<tools>\n");
        String toolsSuffix = valueOrDefault(parameters, "agent.tools.suffix", "</tools>\n");
        String toolPrefix = valueOrDefault(parameters, "agent.tools.tool.prefix", "<tool>\n");
        String toolSuffix = valueOrDefault(parameters, "agent.tools.tool.suffix", "\n</tool>\n");

        StringBuilder toolsBuilder = new StringBuilder(toolsPrefix);
        List<String> toolNames = new ArrayList<>(inputTools.size());
        for (String toolName : inputTools) {
            if (!tools.containsKey(toolName)) {
                throw new IllegalArgumentException("Tool [" + toolName + "] not registered for model");
            }
            toolsBuilder.append(toolPrefix).append(toolName).append(": ").append(tools.get(toolName).getDescription()).append(toolSuffix);
            toolNames.add(toolName);
        }
        toolsBuilder.append(toolsSuffix);

        Map<String, String> toolsPromptMap = new HashMap<>();
        toolsPromptMap.put("tool_descriptions", toolsBuilder.toString());
        // String.join handles the empty list and leaves no dangling separator.
        toolsPromptMap.put("tool_names", String.join(", ", toolNames));
        if (parameters.containsKey("tool_descriptions")) {
            toolsPromptMap.put("tool_descriptions", parameters.get("tool_descriptions"));
        }
        if (parameters.containsKey("tool_names")) {
            toolsPromptMap.put("tool_names", parameters.get("tool_names"));
        }
        return substitute(toolsPromptMap, prompt);
    }

    /**
     * Fills {@code ${parameters.opensearch_indices}} from the {@code opensearch_indices} parameter (a
     * JSON array), wrapping each entry and the whole list with the configurable
     * {@code opensearch_indices.*} prefixes/suffixes. Absent parameter yields an empty string.
     */
    public static String addIndicesToPrompt(Map<String, String> parameters, String prompt) {
        String renderedIndices = "";
        if (parameters.containsKey("opensearch_indices")) {
            StringBuilder indicesBuilder = new StringBuilder();
            appendWrappedItems(
                indicesBuilder,
                parseJsonList(parameters.get("opensearch_indices")),
                valueOrDefault(parameters, "opensearch_indices.prefix", "You have access to the following OpenSearch Index defined in <opensearch_indexes>: \n<opensearch_indexes>\n"),
                valueOrDefault(parameters, "opensearch_indices.suffix", "</opensearch_indexes>\n"),
                valueOrDefault(parameters, "opensearch_indices.index.prefix", "<index>\n"),
                valueOrDefault(parameters, "opensearch_indices.index.suffix", "\n</index>\n")
            );
            renderedIndices = indicesBuilder.toString();
        }
        Map<String, String> indicesMap = new HashMap<>();
        indicesMap.put("opensearch_indices", renderedIndices);
        return substitute(indicesMap, prompt);
    }

    /**
     * Fills {@code ${parameters.chat_history}} from the {@code chat_history} parameter (default "").
     * Side effect: writes the resolved value back into {@code parameters}.
     */
    public static String addChatHistoryToPrompt(Map<String, String> parameters, String prompt) {
        String chatHistory = parameters.getOrDefault("chat_history", "");
        parameters.put("chat_history", chatHistory);
        Map<String, String> chatHistoryMap = new HashMap<>();
        chatHistoryMap.put("chat_history", chatHistory);
        return substitute(chatHistoryMap, prompt);
    }

    /**
     * Fills {@code ${parameters.context}} from the {@code context} parameter (default "").
     * Side effect: writes the resolved value back into {@code parameters}.
     */
    public static String addContextToPrompt(Map<String, String> parameters, String prompt) {
        String context = parameters.getOrDefault("context", "");
        parameters.put("context", context);
        Map<String, String> contextMap = new HashMap<>();
        contextMap.put("context", context);
        return substitute(contextMap, prompt);
    }

    /** Same as {@link #extractModelResponseJson(String, List)} with no caller-supplied patterns. */
    public static String extractModelResponseJson(String text) {
        return extractModelResponseJson(text, null);
    }

    /**
     * Extracts thought/action/action_input/final_answer (plus {@code thought_response}) from the first
     * tensor of the first model output.
     *
     * <p>If the tensor's data map holds only a {@code response} string, that text is parsed (JSON or
     * regex fallback). Otherwise the known fields are read from the map directly, JSON-encoded. An
     * {@code action} is normalized to a name from {@code inputTools} or dropped if none matches; when
     * neither an action nor a final answer remains, {@code thought_response} becomes the final answer.
     *
     * @param tmpModelTensorOutput model output to read
     * @param llmResponsePatterns  optional extra regexes for locating the JSON part of the response
     * @param inputTools           tool names the action may refer to
     * @return parsed fields keyed by field name
     */
    public static Map<String, String> parseLLMOutput(ModelTensorOutput tmpModelTensorOutput, List<String> llmResponsePatterns, Set<String> inputTools) {
        Map<String, String> modelOutput = new HashMap<>();
        Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
        if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
            String llmReasoningResponse = (String) dataAsMap.get("response");
            String thoughtResponse;
            try {
                thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns);
            } catch (IllegalArgumentException e) {
                // No JSON-looking fragment found: fall back to the raw response text.
                thoughtResponse = llmReasoningResponse;
            }
            modelOutput.put("thought_response", thoughtResponse);
            parseThoughtResponse(modelOutput, thoughtResponse);
        } else {
            extractParams(modelOutput, dataAsMap, "thought");
            extractParams(modelOutput, dataAsMap, "action");
            extractParams(modelOutput, dataAsMap, "action_input");
            extractParams(modelOutput, dataAsMap, "final_answer");
            try {
                modelOutput.put("thought_response", StringUtils.toJson(dataAsMap));
            } catch (Exception e) {
                // Best effort: thought_response is informational, so a serialization failure is only logged.
                log.warn("Failed to parse model response", e);
            }
        }
        String action = modelOutput.get("action");
        if (action != null) {
            String matchedTool = getMatchedTool(inputTools, action);
            if (matchedTool != null) {
                modelOutput.put("action", matchedTool);
            } else {
                modelOutput.remove("action");
            }
        }
        if (!modelOutput.containsKey("action") && !modelOutput.containsKey("final_answer")) {
            modelOutput.put("final_answer", modelOutput.get("thought_response"));
        }
        return modelOutput;
    }

    /**
     * Finds the tool an LLM-produced action refers to (case-insensitive).
     *
     * <p>An exact match wins; otherwise the longest tool name contained in {@code action} is chosen,
     * so that e.g. {@code search_index} is preferred over {@code search} regardless of iteration order.
     *
     * @return the matching tool name as it appears in {@code tools}, or {@code null} if none matches
     */
    public static String getMatchedTool(Collection<String> tools, String action) {
        String normalizedAction = action.toLowerCase(Locale.ROOT);
        String bestMatch = null;
        int bestLength = -1;
        for (String tool : tools) {
            String normalizedTool = tool.toLowerCase(Locale.ROOT);
            if (normalizedAction.equals(normalizedTool)) {
                return tool;
            }
            if (normalizedAction.contains(normalizedTool) && normalizedTool.length() > bestLength) {
                bestMatch = tool;
                bestLength = normalizedTool.length();
            }
        }
        return bestMatch;
    }

    /** Copies {@code dataAsMap[paramName]}, JSON-encoded, into {@code modelOutput} when present. */
    public static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {
        if (dataAsMap.containsKey(paramName)) {
            modelOutput.put(paramName, StringUtils.toJson(dataAsMap.get(paramName)));
        }
    }

    /**
     * Locates the JSON part of an LLM response.
     *
     * <p>Content inside a {@code ```json ... ```} fence is used when present. If the trimmed text is
     * valid JSON it is returned; otherwise {@code llmResponsePatterns} (if any) and then
     * {@link #MODEL_RESPONSE_PATTERNS} are tried and the first match is returned.
     *
     * @throws IllegalArgumentException if nothing matches
     */
    public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
        final String jsonFence = "```json";
        if (text.contains(jsonFence)) {
            text = text.substring(text.indexOf(jsonFence) + jsonFence.length());
            if (text.contains("```")) {
                text = text.substring(0, text.lastIndexOf("```"));
            }
        }
        text = text.trim();
        if (StringUtils.isJson(text)) {
            return text;
        }
        if (llmResponsePatterns != null) {
            String matchedPart = findMatchedPart(text, llmResponsePatterns);
            if (matchedPart != null) {
                return matchedPart;
            }
        }
        String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
        if (matchedPart != null) {
            return matchedPart;
        }
        throw new IllegalArgumentException("Model output is invalid");
    }

    /**
     * Puts the fields found in {@code thoughtResponse} into {@code modelOutput}. Valid JSON is flattened
     * via {@code StringUtils.getParameterMap}; otherwise each field is extracted by regex and only
     * non-null values are stored. A {@code null} response is ignored.
     */
    @SuppressWarnings("unchecked")
    public static void parseThoughtResponse(Map<String, String> modelOutput, String thoughtResponse) {
        if (thoughtResponse == null) {
            return;
        }
        if (StringUtils.isJson(thoughtResponse)) {
            modelOutput.putAll(StringUtils.getParameterMap(StringUtils.gson.fromJson(thoughtResponse, Map.class)));
            return;
        }
        String thought = extractThought(thoughtResponse);
        String action = extractAction(thoughtResponse);
        String actionInput = extractActionInput(thoughtResponse);
        String finalAnswer = extractFinalAnswer(thoughtResponse);
        if (thought != null) {
            modelOutput.put("thought", thought);
        }
        if (action != null) {
            modelOutput.put("action", action);
        }
        if (actionInput != null) {
            modelOutput.put("action_input", actionInput);
        }
        if (finalAnswer != null) {
            modelOutput.put("final_answer", finalAnswer);
        }
    }

    /** Returns the {@code final_answer} value (up to the last double quote in the text), or {@code null}. */
    public static String extractFinalAnswer(String text) {
        if (!text.contains("\"final_answer\"")) {
            return null;
        }
        Matcher matcher = FINAL_ANSWER_PATTERN.matcher(text);
        return matcher.find() ? matcher.group(1) : null;
    }

    /**
     * Returns the {@code thought} value, i.e. the text up to the first {@code ", "} that is followed by
     * a {@code final_answer}, {@code action} or {@code action_input} key; {@code null} if not found.
     */
    public static String extractThought(String text) {
        if (!text.contains("\"thought\"")) {
            return null;
        }
        Matcher matcher = THOUGHT_PATTERN.matcher(text);
        return matcher.find() ? matcher.group(1) : null;
    }

    /**
     * Returns the raw text after {@code "action": "} up to the {@code "action_input"} key or end of
     * text; may include trailing quote/comma characters. {@code null} if not found.
     */
    public static String extractAction(String text) {
        if (!text.contains("\"action\"")) {
            return null;
        }
        Matcher matcher = ACTION_PATTERN.matcher(text);
        return matcher.find() ? matcher.group(1) : null;
    }

    /**
     * Returns the {@code action_input} value with {@code \"} unescaped to {@code "}; {@code null} if not
     * found. The capture runs up to the last double quote in the text.
     */
    public static String extractActionInput(String text) {
        if (!text.contains("\"action_input\"")) {
            return null;
        }
        Matcher matcher = ACTION_INPUT_PATTERN.matcher(text);
        if (!matcher.find()) {
            return null;
        }
        return matcher.group(1).replace("\\\"", "\"");
    }

    /** Returns the first match of the first pattern in {@code patternList} that matches {@code text}, else {@code null}. */
    public static String findMatchedPart(String text, List<String> patternList) {
        for (String regex : patternList) {
            Matcher matcher = Pattern.compile(regex).matcher(text);
            if (matcher.find()) {
                return matcher.group();
            }
        }
        return null;
    }

    /**
     * Converts a tool/model output to a string.
     *
     * <p>For a {@link ModelTensorOutput}, the first tensor's data map is serialized to JSON, or its
     * {@code result} is returned when the map is null. Strings are returned unchanged; anything else is
     * serialized with gson. Serialization runs inside {@code AccessController.doPrivileged}.
     *
     * @throws PrivilegedActionException if serialization fails inside the privileged block
     */
    public static String outputToOutputString(Object output) throws PrivilegedActionException {
        if (output instanceof ModelTensorOutput) {
            ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
            if (outputModel.getDataAsMap() != null) {
                return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> StringUtils.gson.toJson(outputModel.getDataAsMap()));
            }
            return outputModel.getResult();
        }
        if (output instanceof String) {
            return (String) output;
        }
        return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> StringUtils.gson.toJson(output));
    }

    /**
     * Returns {@code message_history_limit} as an int, or
     * {@code ConversationIndexMemory.LAST_N_INTERACTIONS} when the parameter is absent.
     *
     * @throws NumberFormatException if the parameter is present but not an integer
     */
    public static int getMessageHistoryLimit(Map<String, String> params) {
        String messageHistoryLimitStr = params.get("message_history_limit");
        return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : ConversationIndexMemory.LAST_N_INTERACTIONS;
    }

    /** Returns the spec's name, falling back to its type when no name is set. */
    public static String getToolName(MLToolSpec toolSpec) {
        return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
    }

    /**
     * Returns the agent's tool specs, filtered and ordered by the {@code selected_tools} JSON array
     * when that parameter is non-empty. Unknown selected names are skipped.
     */
    public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
        List<MLToolSpec> toolSpecs = mlAgent.getTools();
        String selectedToolsStr = params.get(SELECTED_TOOLS);
        if (Strings.isEmpty((CharSequence) selectedToolsStr)) {
            return toolSpecs;
        }
        List<?> selectedTools = StringUtils.gson.fromJson(selectedToolsStr, List.class);
        Map<String, MLToolSpec> toolNameSpecMap = new HashMap<>();
        for (MLToolSpec toolSpec : toolSpecs) {
            toolNameSpecMap.put(getToolName(toolSpec), toolSpec);
        }
        List<MLToolSpec> selectedToolSpecs = new ArrayList<>();
        for (Object tool : selectedTools) {
            MLToolSpec selectedSpec = toolNameSpecMap.get(tool);
            if (selectedSpec != null) {
                selectedToolSpecs.add(selectedSpec);
            }
        }
        return selectedToolSpecs;
    }

    /**
     * Instantiates a {@link Tool} for every spec and registers it in {@code tools} and
     * {@code toolSpecMap} under the tool's name. Both maps are mutated.
     */
    public static void createTools(Map<String, Tool.Factory> toolFactories, Map<String, String> params, List<MLToolSpec> toolSpecs, Map<String, Tool> tools, Map<String, MLToolSpec> toolSpecMap, MLAgent mlAgent) {
        for (MLToolSpec toolSpec : toolSpecs) {
            Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());
            tools.put(tool.getName(), tool);
            toolSpecMap.put(tool.getName(), toolSpec);
        }
    }

    /**
     * Builds a tool from its factory.
     *
     * <p>Factory parameters are the spec's parameters, {@code tenant_id}, and every entry of
     * {@code params} whose key starts with {@code "<toolName>."} (with that leading prefix removed).
     * The tool's name is set from the spec; its description comes from the spec or, if present, the
     * {@code "<toolName>.description"} parameter.
     *
     * @throws IllegalArgumentException if no factory is registered for the spec's type
     */
    public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> params, MLToolSpec toolSpec, String tenantId) {
        if (!toolFactories.containsKey(toolSpec.getType())) {
            throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
        }
        String toolName = getToolName(toolSpec);
        String toolNamePrefix = toolName + ".";

        Map<String, Object> executeParams = new HashMap<>();
        if (toolSpec.getParameters() != null) {
            executeParams.putAll(toolSpec.getParameters());
        }
        executeParams.put("tenant_id", tenantId);
        for (Map.Entry<String, String> entry : params.entrySet()) {
            String key = entry.getKey();
            if (key.startsWith(toolNamePrefix)) {
                // Strip only the leading prefix; String.replace would also remove later occurrences.
                executeParams.put(key.substring(toolNamePrefix.length()), entry.getValue());
            }
        }

        Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams);
        tool.setName(toolName);
        if (toolSpec.getDescription() != null) {
            tool.setDescription(toolSpec.getDescription());
        }
        if (params.containsKey(toolName + ".description")) {
            tool.setDescription(params.get(toolName + ".description"));
        }
        return tool;
    }

    /** Returns the names reported by each tool in {@code tools}, in map iteration order. */
    public static List<String> getToolNames(Map<String, Tool> tools) {
        List<String> inputTools = new ArrayList<>(tools.size());
        for (Tool tool : tools.values()) {
            inputTools.add(tool.getName());
        }
        return inputTools;
    }

    /**
     * Builds the parameter map for invoking tool {@code action}.
     *
     * <p>Merges, in order: the spec's parameters, its config map, {@code llm_generated_input}, and the
     * flattened {@code actionInput} when it is JSON. The {@code input} entry is then the original
     * {@code question} (also stored in {@code lastActionInput}) for tools using original input; else the
     * config map's {@code input} template with {@code ${parameters.*}} substituted (and flattened into
     * the map when JSON); else {@code actionInput}.
     */
    @SuppressWarnings("unchecked")
    public static Map<String, String> constructToolParams(Map<String, Tool> tools, Map<String, MLToolSpec> toolSpecMap, String question, AtomicReference<String> lastActionInput, String action, String actionInput) {
        Map<String, String> toolParams = new HashMap<>();
        MLToolSpec toolSpec = toolSpecMap.get(action);
        Map<String, String> toolSpecParams = toolSpec.getParameters();
        Map<String, String> toolSpecConfigMap = toolSpec.getConfigMap();
        if (toolSpecParams != null) {
            toolParams.putAll(toolSpecParams);
        }
        if (toolSpecConfigMap != null) {
            toolParams.putAll(toolSpecConfigMap);
        }
        toolParams.put(LLM_GEN_INPUT, actionInput);
        if (StringUtils.isJson(actionInput)) {
            toolParams.putAll(StringUtils.getParameterMap(StringUtils.gson.fromJson(actionInput, Map.class)));
        }
        if (tools.get(action).useOriginalInput()) {
            toolParams.put("input", question);
            lastActionInput.set(question);
        } else if (toolSpecConfigMap != null && toolSpecConfigMap.containsKey("input")) {
            String input = substitute(toolParams, toolSpecConfigMap.get("input"));
            toolParams.put("input", input);
            if (StringUtils.isJson(input)) {
                toolParams.putAll(StringUtils.getParameterMap(StringUtils.gson.fromJson(input, Map.class)));
            }
        } else {
            toolParams.put("input", actionInput);
        }
        return toolParams;
    }
}

