/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.search.processor.mmr;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.opensearch.action.IndicesRequest;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.search.extension.MMRSearchExtBuilder;
import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer;
import org.opensearch.knn.search.processor.mmr.MMRRerankContext;
import org.opensearch.knn.search.processor.mmr.MMRTransformContext;
import org.opensearch.knn.search.processor.mmr.MMRUtil;
import org.opensearch.knn.search.processor.mmr.MMRVectorFieldInfo;
import org.opensearch.search.fetch.StoredFieldsContext;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.ProcessorGenerationContext;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SystemGeneratedProcessor;
import org.opensearch.transport.client.Client;

/**
 * System-generated search request processor that prepares a search request for Maximal Marginal Relevance (MMR)
 * reranking.
 * <p>
 * For a request carrying an {@link MMRSearchExtBuilder}, this processor:
 * <ul>
 *   <li>enlarges the request size to the MMR candidate count (oversampling),</li>
 *   <li>forces full {@code _source} fetching while remembering the user's original source settings,</li>
 *   <li>optionally resolves vector field info for a user-provided vector field path, and</li>
 *   <li>hands the query to the {@link MMRQueryTransformer} registered for its query type.</li>
 * </ul>
 * On success the {@link MMRRerankContext} is stored in the pipeline context under {@code "mmr.rerank_context"}.
 * Only asynchronous execution ({@link #processRequestAsync}) is supported; the synchronous variants throw.
 */
public class MMROverSampleProcessor implements SearchRequestProcessor, SystemGeneratedProcessor {
    public static final String TYPE = "mmr_over_sample";
    public static final String DESCRIPTION =
        "This is a system generated processor that will modify the query size and k of the knn query to oversample for Maximal Marginal Relevance rerank.";
    /** Request size value that this processor treats as "size not set by the user". */
    private static final int DEFAULT_QUERY_SIZE_INDICATOR = -1;
    /** Size assumed when the request did not specify one. */
    private static final int DEFAULT_QUERY_SIZE = 10;
    /** Multiplier applied to the query size to derive the candidate count when none is configured. */
    private static final int DEFAULT_OVERSAMPLE_SCALE = 3;
    /** Index expressions containing this separator are treated as remote indices. */
    private static final String REMOTE_INDEX_SEPARATOR = ":";

    private final String tag;
    private final boolean ignoreFailure;
    private final Client client;
    /** Query transformers keyed by the query builder's writeable name. */
    private final Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers;

    public MMROverSampleProcessor(
        String tag,
        boolean ignoreFailure,
        Client client,
        Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers
    ) {
        this.tag = tag;
        this.ignoreFailure = ignoreFailure;
        this.client = client;
        this.mmrQueryTransformers = mmrQueryTransformers;
    }

    /**
     * Not supported: this processor only runs asynchronously.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SearchRequest processRequest(SearchRequest searchRequest) {
        throw new UnsupportedOperationException(
            String.format(Locale.ROOT, "Should not try to use %s to process a search request synchronously.", TYPE)
        );
    }

    /**
     * Not supported: this processor only runs asynchronously.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) {
        throw new UnsupportedOperationException(
            String.format(Locale.ROOT, "Should not try to use %s to process a search request synchronously.", TYPE)
        );
    }

    /**
     * Rewrites {@code request} for MMR oversampling and reports the result through {@code requestListener}.
     * <p>
     * Every failure, whether synchronous or from the asynchronous steps, is delivered via
     * {@link ActionListener#onFailure}; nothing is thrown to the caller.
     *
     * @param request         the search request; must have a source with an MMR search extension
     * @param requestContext  pipeline context that receives the {@link MMRRerankContext} on success
     * @param requestListener receives the (mutated) request or the failure
     */
    @Override
    public void processRequestAsync(
        SearchRequest request,
        PipelineProcessingContext requestContext,
        ActionListener<SearchRequest> requestListener
    ) {
        try {
            if (request == null || request.source() == null || request.source().ext() == null) {
                throw new IllegalStateException(
                    String.format(
                        Locale.ROOT,
                        "Search request passed to %s search request processor must have mmr search extension.",
                        TYPE
                    )
                );
            }
            MMRSearchExtBuilder mmrSearchExtBuilder = extractMMRExtension(request);

            String[] allTargetIndices = request.indices();
            List<String> remoteIndices = splitIndices(allTargetIndices, REMOTE_INDEX_SEPARATOR, true);
            List<String> localIndices = splitIndices(allTargetIndices, REMOTE_INDEX_SEPARATOR, false);

            MMRRerankContext mmrRerankContext = new MMRRerankContext();
            mmrRerankContext.setDiversity(mmrSearchExtBuilder.getDiversity());
            // Validate before any mutation of the request below.
            validateForRemoteIndices(mmrSearchExtBuilder, remoteIndices);
            int candidates = computeCandidatesAndSetRequestSize(mmrRerankContext, request, mmrSearchExtBuilder);
            preserveAndEnableFullSource(request, mmrRerankContext);

            // Only local indices have metadata available through the cluster util.
            OriginalIndices localIndicesSearchRequest = new OriginalIndices(
                localIndices.toArray(String[]::new),
                request.indicesOptions()
            );
            List<IndexMetadata> localIndexMetadataList = getLocalIndexMetadata(localIndicesSearchRequest);

            String userProvidedVectorFieldPath = mmrSearchExtBuilder.getVectorFieldPath();
            VectorDataType userProvidedVectorDataType = mmrSearchExtBuilder.getVectorFieldDataType();
            SpaceType userProvidedSpaceType = mmrSearchExtBuilder.getSpaceType();
            // NOTE(review): the trailing `false` appears to be the initial "vector field info resolved" flag
            // (it is later set via setVectorFieldInfoResolved(true)) — confirm against MMRTransformContext.
            MMRTransformContext mmrTransformContext = new MMRTransformContext(
                candidates,
                mmrRerankContext,
                localIndexMetadataList,
                remoteIndices,
                userProvidedSpaceType,
                userProvidedVectorFieldPath,
                userProvidedVectorDataType,
                client,
                false
            );

            if (userProvidedVectorFieldPath != null) {
                processWithUserProvidedVectorFieldPath(request, requestContext, requestListener, mmrTransformContext);
                return;
            }
            transformQueryForMMR(request, requestListener, mmrTransformContext, requestContext);
        } catch (Exception e) {
            requestListener.onFailure(e);
        }
    }

    /**
     * Resolves vector data type and space type for the user-provided vector field path.
     * <p>
     * Once they are recorded on the rerank context, the query transformation continues.
     */
    private void processWithUserProvidedVectorFieldPath(
        SearchRequest request,
        PipelineProcessingContext requestContext,
        ActionListener<SearchRequest> requestListener,
        MMRTransformContext mmrTransformContext
    ) {
        try {
            String userProvidedVectorFieldPath = mmrTransformContext.getUserProvidedVectorFieldPath();
            SpaceType userProvidedSpaceType = mmrTransformContext.getUserProvidedSpaceType();
            VectorDataType userProvidedVectorDataType = mmrTransformContext.getUserProvidedVectorDataType();
            List<IndexMetadata> localIndexMetadataList = mmrTransformContext.getLocalIndexMetadataList();
            MMRRerankContext mmrRerankContext = mmrTransformContext.getMmrRerankContext();
            mmrRerankContext.setVectorFieldPath(userProvidedVectorFieldPath);

            // ActionListener.wrap routes exceptions thrown by the response handler (e.g. from
            // transformQueryForMMR) to the failure handler.
            ActionListener<MMRVectorFieldInfo> vectorFieldInfoListener = ActionListener.wrap(vectorFieldInfo -> {
                mmrRerankContext.setVectorDataType(vectorFieldInfo.getVectorDataType());
                mmrRerankContext.setSpaceType(vectorFieldInfo.getSpaceType());
                mmrTransformContext.setVectorFieldInfoResolved(true);
                transformQueryForMMR(request, requestListener, mmrTransformContext, requestContext);
            }, requestListener::onFailure);

            MMRUtil.resolveKnnVectorFieldInfo(
                userProvidedVectorFieldPath,
                userProvidedSpaceType,
                userProvidedVectorDataType,
                localIndexMetadataList,
                client,
                vectorFieldInfoListener
            );
        } catch (Exception e) {
            requestListener.onFailure(e);
        }
    }

    /**
     * Returns the first {@link MMRSearchExtBuilder} in the request's search extensions.
     *
     * @throws IllegalStateException if the request has no MMR extension
     */
    private MMRSearchExtBuilder extractMMRExtension(SearchRequest request) {
        return request.source()
            .ext()
            .stream()
            .filter(MMRSearchExtBuilder.class::isInstance)
            .map(MMRSearchExtBuilder.class::cast)
            .findFirst()
            .orElseThrow(
                () -> new IllegalStateException(
                    String.format(Locale.ROOT, "SearchRequest passed to %s processor must have an MMRSearchExtBuilder", TYPE)
                )
            );
    }

    /**
     * Returns the index expressions that do ({@code remote == true}) or do not ({@code remote == false})
     * contain {@code separator}.
     */
    private List<String> splitIndices(String[] indices, String separator, boolean remote) {
        return Arrays.stream(indices).filter(index -> index.contains(separator) == remote).toList();
    }

    /**
     * Requires space type and vector data type in the MMR extension when remote indices are targeted.
     * <p>
     * This processor does not read index metadata for remote indices, so these values must come from the user.
     *
     * @throws IllegalArgumentException if remote indices are present and either value is missing
     */
    private void validateForRemoteIndices(MMRSearchExtBuilder mmrSearchExtBuilder, List<String> remoteIndices) {
        if (remoteIndices.isEmpty()) {
            return;
        }
        String indicesString = String.join(",", remoteIndices);
        if (mmrSearchExtBuilder.getSpaceType() == null) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "%s is required in the MMR query extension when querying remote indices [%s].",
                    MMRSearchExtBuilder.VECTOR_FIELD_SPACE_TYPE_FIELD.getPreferredName(),
                    indicesString
                )
            );
        }
        if (mmrSearchExtBuilder.getVectorFieldDataType() == null) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "%s is required in the MMR query extension when querying remote indices [%s].",
                    MMRSearchExtBuilder.VECTOR_FIELD_DATA_TYPE_FIELD.getPreferredName(),
                    indicesString
                )
            );
        }
    }

    /** Looks up cluster index metadata for the given local indices. */
    private List<IndexMetadata> getLocalIndexMetadata(OriginalIndices localIndicesSearchRequest) {
        return KNNClusterUtil.instance().getIndexMetadataList(localIndicesSearchRequest);
    }

    /**
     * Records the user's effective query size and replaces the request size with the MMR candidate count.
     * <p>
     * The candidate count is the configured {@code candidates} value, or
     * {@link #DEFAULT_OVERSAMPLE_SCALE} times the query size when it is not set.
     *
     * @return the candidate count now set as the request size
     */
    private int computeCandidatesAndSetRequestSize(
        MMRRerankContext mmrRerankContext,
        SearchRequest request,
        MMRSearchExtBuilder mmrSearchExtBuilder
    ) {
        int originalQuerySize = request.source().size();
        if (originalQuerySize == DEFAULT_QUERY_SIZE_INDICATOR) {
            originalQuerySize = DEFAULT_QUERY_SIZE;
        }
        mmrRerankContext.setOriginalQuerySize(originalQuerySize);

        Integer configuredCandidates = mmrSearchExtBuilder.getCandidates();
        int candidates = configuredCandidates != null ? configuredCandidates : DEFAULT_OVERSAMPLE_SCALE * originalQuerySize;
        request.source().size(candidates);
        return candidates;
    }

    /**
     * Ensures the request fetches the full {@code _source}.
     * <p>
     * The user's original source setting is recorded on {@code mmrContext} whenever it is changed.
     * NOTE(review): the recorded setting is presumably used later to restore the user-visible response shape.
     *
     * @throws IllegalArgumentException if stored fields are disabled while {@code _source} is explicitly requested
     */
    private void preserveAndEnableFullSource(SearchRequest request, MMRRerankContext mmrContext) {
        FetchSourceContext currentSourceContext = request.source().fetchSource();
        StoredFieldsContext storedFieldsContext = request.source().storedFields();
        if (storedFieldsContext != null) {
            if (isStoredFieldsDisabled(storedFieldsContext)) {
                handleDisabledStoredFields(request, mmrContext, currentSourceContext);
                return;
            }
            // NOTE(review): presumably stored_fields without an explicit _source setting would not return
            // _source, hence the user's effective setting is recorded as "disabled".
            if (isSourceNotExplicitlySet(currentSourceContext)) {
                enableSourceTemporarily(request, mmrContext);
                return;
            }
        }
        if (isAlreadyFetchingFullSource(currentSourceContext)) {
            return;
        }
        preserveAndEnableFullSourceFetch(request, mmrContext, currentSourceContext);
    }

    /** Returns {@code true} when the stored fields context disables field fetching. */
    private boolean isStoredFieldsDisabled(StoredFieldsContext context) {
        return !context.fetchFields();
    }

    /** Returns {@code true} when the request has no explicit {@code _source} setting. */
    private boolean isSourceNotExplicitlySet(FetchSourceContext sourceContext) {
        return sourceContext == null;
    }

    /**
     * Returns {@code true} when no source context is set, or when it fetches source with no include or exclude filters.
     */
    private boolean isAlreadyFetchingFullSource(FetchSourceContext sourceContext) {
        if (sourceContext == null) {
            return true;
        }
        return sourceContext.fetchSource() && sourceContext.includes().length == 0 && sourceContext.excludes().length == 0;
    }

    /**
     * Handles a request whose stored fields are disabled.
     * <p>
     * It rejects an explicit {@code _source} setting. Otherwise it records source as disabled, replaces stored
     * fields with an empty list, and enables {@code _source}.
     */
    private void handleDisabledStoredFields(
        SearchRequest request,
        MMRRerankContext mmrContext,
        FetchSourceContext currentSourceContext
    ) {
        if (currentSourceContext != null) {
            throw new IllegalArgumentException("[stored_fields] cannot be disabled if [_source] is requested");
        }
        mmrContext.setOriginalFetchSourceContext(new FetchSourceContext(false));
        request.source().storedFields(StoredFieldsContext.fromList(Collections.emptyList()));
        request.source().fetchSource(new FetchSourceContext(true));
    }

    /** Records source as originally disabled and enables full {@code _source} fetching. */
    private void enableSourceTemporarily(SearchRequest request, MMRRerankContext mmrContext) {
        mmrContext.setOriginalFetchSourceContext(new FetchSourceContext(false));
        request.source().fetchSource(new FetchSourceContext(true));
    }

    /** Records the user's source context and replaces it with full {@code _source} fetching. */
    private void preserveAndEnableFullSourceFetch(
        SearchRequest request,
        MMRRerankContext mmrContext,
        FetchSourceContext currentSourceContext
    ) {
        mmrContext.setOriginalFetchSourceContext(currentSourceContext);
        request.source().fetchSource(new FetchSourceContext(true));
    }

    /**
     * Transforms the request's query with the transformer registered for its query type.
     * <p>
     * On success it stores the rerank context in {@code requestContext} and completes {@code requestListener}
     * with the request.
     *
     * @throws IllegalArgumentException if the request has no query or its query type has no registered transformer
     */
    private void transformQueryForMMR(
        final SearchRequest request,
        final ActionListener<SearchRequest> requestListener,
        final MMRTransformContext mmrTransformationContext,
        final PipelineProcessingContext requestContext
    ) {
        QueryBuilder queryBuilder = request.source().query();
        if (queryBuilder == null) {
            throw new IllegalArgumentException("Query builder must not be null to do Maximal Marginal Relevance rerank.");
        }
        String queryType = queryBuilder.getWriteableName();
        // Transformers are keyed by writeable name, so the transformer found here is assumed to accept this
        // query's concrete type; the wildcard must be widened to pass the QueryBuilder in.
        @SuppressWarnings("unchecked")
        MMRQueryTransformer<QueryBuilder> transformer = (MMRQueryTransformer<QueryBuilder>) mmrQueryTransformers.get(queryType);
        if (transformer == null) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Maximal Marginal Relevance rerank doesn't support the query type [%s]", queryType)
            );
        }
        transformer.transform(queryBuilder, new ActionListener<Void>() {
            @Override
            public void onResponse(Void unused) {
                requestContext.setAttribute("mmr.rerank_context", mmrTransformationContext.getMmrRerankContext());
                requestListener.onResponse(request);
            }

            @Override
            public void onFailure(Exception e) {
                requestListener.onFailure(e);
            }
        }, mmrTransformationContext);
    }

    /** Runs after the user-defined request processors. */
    @Override
    public SystemGeneratedProcessor.ExecutionStage getExecutionStage() {
        return SystemGeneratedProcessor.ExecutionStage.POST_USER_DEFINED;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Override
    public String getTag() {
        return tag;
    }

    @Override
    public String getDescription() {
        return DESCRIPTION;
    }

    @Override
    public boolean isIgnoreFailure() {
        return ignoreFailure;
    }

    /** Factory that generates {@link MMROverSampleProcessor} instances when {@link MMRUtil} decides MMR applies. */
    public static class MMROverSampleProcessorFactory implements SystemGeneratedProcessor.SystemGeneratedFactory<SearchRequestProcessor> {
        public static final String TYPE = "mmr_over_sample_factory";
        private final Client client;
        private final Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers;

        @Override
        public boolean shouldGenerate(ProcessorGenerationContext processorGenerationContext) {
            return MMRUtil.shouldGenerateMMRProcessor(processorGenerationContext);
        }

        @Override
        public SearchRequestProcessor create(
            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
            String tag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) throws Exception {
            return new MMROverSampleProcessor(tag, ignoreFailure, client, mmrQueryTransformers);
        }

        @Generated
        public MMROverSampleProcessorFactory(
            Client client,
            Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers
        ) {
            this.client = client;
            this.mmrQueryTransformers = mmrQueryTransformers;
        }
    }
}

