/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.codec;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.backward_codecs.lucene101.Lucene101PostingsFormat;
import org.apache.lucene.codecs.BlockTermState;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PushPostingsWriterBase;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.ClusteringTask;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.RandomClusteringAlgorithm;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.SeismicPostingClusterer;
import org.opensearch.neuralsearch.sparse.cache.CacheGatedForwardIndexReader;
import org.opensearch.neuralsearch.sparse.cache.CacheKey;
import org.opensearch.neuralsearch.sparse.cache.CacheableClusteredPostingWriter;
import org.opensearch.neuralsearch.sparse.cache.ClusteredPostingCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCacheItem;
import org.opensearch.neuralsearch.sparse.codec.CodecUtilWrapper;
import org.opensearch.neuralsearch.sparse.codec.SparseBinaryDocValuesPassThrough;
import org.opensearch.neuralsearch.sparse.common.IteratorWrapper;
import org.opensearch.neuralsearch.sparse.common.ValueEncoder;
import org.opensearch.neuralsearch.sparse.data.DocWeight;
import org.opensearch.neuralsearch.sparse.data.DocumentCluster;
import org.opensearch.neuralsearch.sparse.data.PostingClusters;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizer;

/**
 * Postings writer that buffers the (docID, quantized weight) pairs of each term, turns them
 * into {@link PostingClusters} through a {@link ClusteringTask}, and serializes the clusters
 * to the postings output.
 *
 * <p>Serialized layout of one term, as produced by {@code writePostingClusters}:
 * <pre>
 *   VLong  clusterCount
 *   repeated clusterCount times:
 *     VLong  docCount
 *     repeated docCount times:   VInt docID, byte weight
 *     byte   shouldNotSkip (1 or 0)
 *     VLong  summarySize (0 when the cluster has no summary)
 *     repeated summarySize times: VInt token, byte weight
 * </pre>
 *
 * <p>{@link #init(IndexOutput, SegmentWriteState)} must run before
 * {@link #setFieldAndMaxDoc(FieldInfo, int, boolean)} because the latter reads the segment
 * state stored by the former.
 */
public class ClusteredPostingTermsWriter
extends PushPostingsWriterBase {
    @Generated
    private static final Logger log = LogManager.getLogger(ClusteredPostingTermsWriter.class);

    /** Value of the {@code n_postings} attribute that requests a size derived from maxDoc. */
    private static final int AUTO_N_POSTINGS = -1;
    /** Fraction of maxDoc used when {@code n_postings} is {@link #AUTO_N_POSTINGS}. */
    private static final float AUTO_N_POSTINGS_RATIO = 5.0E-4f;
    /** Lower bound applied when {@code n_postings} is {@link #AUTO_N_POSTINGS}. */
    private static final int AUTO_N_POSTINGS_MINIMUM = 160;

    private FixedBitSet docsSeen;
    private IndexOutput postingOut;
    /** Postings collected for the term currently being written; reset per term. */
    private final List<DocWeight> docWeights = new ArrayList<DocWeight>();
    private BytesRef currentTerm;
    private SeismicPostingClusterer seismicPostingClusterer;
    private CacheKey key;
    private final String codecName;
    private final int version;
    private SegmentWriteState state;
    /** Doc-values producer opened by {@code setSeismicPostingClusterer}; owned by this writer. */
    private DocValuesProducer docValuesProducer;
    private ByteQuantizer byteQuantizer = new ByteQuantizer(3.0f);
    private final CodecUtilWrapper codecUtilWrapper;

    /**
     * Selects the field being written and replaces the quantizer with the one
     * {@link ByteQuantizationUtil#getByteQuantizerIngest(FieldInfo)} returns for it.
     *
     * @param fieldInfo field whose terms will be written next
     */
    public void setField(FieldInfo fieldInfo) {
        super.setField(fieldInfo);
        this.byteQuantizer = ByteQuantizationUtil.getByteQuantizerIngest(fieldInfo);
    }

    /**
     * Writes one term by pushing its postings through the base class, which calls back into
     * {@code startTerm}/{@code startDoc}/{@code finishTerm} of this writer.
     *
     * @param text      term bytes
     * @param termsEnum enum positioned on the term
     * @param norms     norms producer handed to the base class
     * @return the term state returned by {@code writeTerm}
     * @throws IOException if writing fails
     */
    public BlockTermState write(BytesRef text, TermsEnum termsEnum, NormsProducer norms) throws IOException {
        this.currentTerm = text;
        return super.writeTerm(text, termsEnum, this.docsSeen, norms);
    }

    /**
     * Writes already-built clusters for a term without running the clusterer.
     *
     * @param text            term bytes
     * @param postingClusters clusters to serialize
     * @return a new term state whose {@code blockFilePointer} marks where the clusters start
     * @throws IOException if writing fails
     */
    public BlockTermState write(BytesRef text, PostingClusters postingClusters) throws IOException {
        this.currentTerm = text;
        BlockTermState termState = this.newTermState();
        this.writePostingClusters(postingClusters, termState);
        return termState;
    }

    /**
     * Selects the field, derives the cache key from the segment and field, and, unless this is
     * a merge, builds the posting clusterer for the field.
     *
     * @param fieldInfo field whose terms will be written next
     * @param maxDoc    document count used to size the forward index and the posting limit
     * @param isMerge   when {@code true} no clusterer is created
     */
    public void setFieldAndMaxDoc(FieldInfo fieldInfo, int maxDoc, boolean isMerge) {
        this.setField(fieldInfo);
        this.key = new CacheKey(this.state.segmentInfo, fieldInfo);
        if (isMerge) {
            return;
        }
        this.setSeismicPostingClusterer(maxDoc);
    }

    /**
     * @return a fresh {@link Lucene101PostingsFormat.IntBlockTermState}
     */
    public BlockTermState newTermState() throws IOException {
        return new Lucene101PostingsFormat.IntBlockTermState();
    }

    /**
     * Begins a new term by discarding postings buffered for the previous one.
     *
     * @param norms unused
     */
    public void startTerm(NumericDocValues norms) throws IOException {
        this.docWeights.clear();
    }

    /**
     * Builds the {@link SeismicPostingClusterer} for the current field from the field attributes
     * {@code cluster_ratio}, {@code n_postings} and {@code summary_prune_ratio}.
     *
     * <p>A doc-values producer is opened so the clustering algorithm can fall back to the binary
     * doc values of the field; when that fails the error is logged and the fallback reader is
     * left {@code null}.
     *
     * @param maxDoc document count used for the forward index and the automatic posting limit
     */
    private void setSeismicPostingClusterer(int maxDoc) {
        ForwardIndexCacheItem index = ForwardIndexCache.getInstance().getOrCreate(this.key, maxDoc);

        // A producer left over from a previous field would otherwise be overwritten and leaked;
        // the clusterer that read from it is replaced at the end of this method.
        if (this.docValuesProducer != null) {
            IOUtils.closeWhileHandlingException((Closeable)this.docValuesProducer);
            this.docValuesProducer = null;
        }

        SparseBinaryDocValuesPassThrough luceneReader = null;
        DocValuesFormat fmt = this.state.segmentInfo.getCodec().docValuesFormat();
        SegmentReadState readState = new SegmentReadState(this.state.directory, this.state.segmentInfo, this.state.fieldInfos, IOContext.DEFAULT);
        try {
            this.docValuesProducer = fmt.fieldsProducer(readState);
            BinaryDocValues binaryDocValues = this.docValuesProducer.getBinary(this.fieldInfo);
            if (binaryDocValues != null) {
                luceneReader = new SparseBinaryDocValuesPassThrough(binaryDocValues, this.state.segmentInfo, this.fieldInfo);
            }
        }
        catch (Exception e) {
            // Best effort: clustering proceeds without the Lucene-backed reader, but keep the cause.
            log.error("Failed to retrieve lucene reader", (Throwable)e);
        }

        float clusterRatio = Float.parseFloat(this.fieldInfo.attributes().get("cluster_ratio"));
        int configuredPostings = Integer.parseInt(this.fieldInfo.attributes().get("n_postings"));
        int nPostings = configuredPostings == AUTO_N_POSTINGS
            ? Math.max((int)(AUTO_N_POSTINGS_RATIO * (float)maxDoc), AUTO_N_POSTINGS_MINIMUM)
            : configuredPostings;
        float summaryPruneRatio = Float.parseFloat(this.fieldInfo.attributes().get("summary_prune_ratio"));

        CacheGatedForwardIndexReader forwardReader = new CacheGatedForwardIndexReader(index.getReader(), index.getWriter(), luceneReader);
        RandomClusteringAlgorithm algorithm = new RandomClusteringAlgorithm(summaryPruneRatio, clusterRatio, forwardReader);
        this.seismicPostingClusterer = new SeismicPostingClusterer(nPostings, algorithm);
    }

    /**
     * Serializes the clusters at the current output position and records that position in
     * {@code state.blockFilePointer}. See the class comment for the byte layout.
     *
     * @param postingClusters clusters to serialize
     * @param state           term state that receives the start file pointer
     * @throws IOException if writing fails
     */
    private void writePostingClusters(PostingClusters postingClusters, BlockTermState state) throws IOException {
        List<DocumentCluster> clusters = postingClusters.getClusters();
        state.blockFilePointer = this.postingOut.getFilePointer();
        this.postingOut.writeVLong((long)clusters.size());
        for (DocumentCluster cluster : clusters) {
            // Documents of the cluster.
            this.postingOut.writeVLong((long)cluster.size());
            Iterator<DocWeight> docs = cluster.iterator();
            while (docs.hasNext()) {
                DocWeight docWeight = docs.next();
                this.postingOut.writeVInt(docWeight.getDocID());
                this.postingOut.writeByte(docWeight.getWeight());
            }

            this.postingOut.writeByte((byte)(cluster.isShouldNotSkip() ? 1 : 0));

            // Summary vector; a zero size stands for "no summary".
            if (cluster.getSummary() == null) {
                this.postingOut.writeVLong(0L);
                continue;
            }
            IteratorWrapper<SparseVector.Item> summaryItems = cluster.getSummary().iterator();
            this.postingOut.writeVLong((long)cluster.getSummary().getSize());
            while (summaryItems.hasNext()) {
                SparseVector.Item item = summaryItems.next();
                this.postingOut.writeVInt(item.getToken());
                this.postingOut.writeByte(item.getWeight());
            }
        }
    }

    /**
     * Clusters the postings buffered for the current term, writes the clusters, and resets the
     * per-term buffers.
     *
     * @param state term state that receives the start file pointer of the written clusters
     * @throws IOException if writing fails
     */
    public void finishTerm(BlockTermState state) throws IOException {
        CacheableClusteredPostingWriter writer = ClusteredPostingCache.getInstance().getOrCreate(this.key).getWriter();
        PostingClusters postingClusters = new ClusteringTask(this.currentTerm, this.docWeights, writer, this.seismicPostingClusterer).get();
        this.writePostingClusters(postingClusters, state);
        this.docWeights.clear();
        this.currentTerm = null;
    }

    /**
     * Buffers one posting: the frequency is decoded with
     * {@link ValueEncoder#decodeFeatureValue} and quantized to a byte weight.
     *
     * @param docID document id; {@code -1} is rejected
     * @param freq  encoded feature value of the term in the document
     * @throws IllegalStateException if {@code docID} is {@code -1}
     */
    public void startDoc(int docID, int freq) throws IOException {
        if (docID == -1) {
            throw new IllegalStateException("docId must be set before startDoc");
        }
        byte weight = this.byteQuantizer.quantize(ValueEncoder.decodeFeatureValue(freq));
        this.docWeights.add(new DocWeight(docID, weight));
    }

    /**
     * Positions are not supported by this writer.
     *
     * @throws UnsupportedOperationException always
     */
    public void addPosition(int position, BytesRef payload, int startOffset, int endOffset) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** No per-document work is needed; the posting was already buffered in {@code startDoc}. */
    public void finishDoc() throws IOException {
    }

    /**
     * Binds this writer to its output and segment, and writes the index header.
     *
     * @param termsOut output that receives the header, the clusters and the footer
     * @param state    write state of the segment being flushed
     * @throws IOException if the header cannot be written
     */
    public void init(IndexOutput termsOut, SegmentWriteState state) throws IOException {
        this.postingOut = termsOut;
        this.state = state;
        this.docsSeen = new FixedBitSet(state.segmentInfo.maxDoc());
        this.codecUtilWrapper.writeIndexHeader((DataOutput)this.postingOut, this.codecName, this.version, state.segmentInfo.getId(), state.segmentSuffix);
    }

    /**
     * Term metadata encoding is not supported by this writer.
     *
     * @throws UnsupportedOperationException always
     */
    public void encodeTerm(DataOutput out, FieldInfo fieldInfo, BlockTermState state, boolean absolute) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * Writes the footer and releases the doc-values producer, if one was opened.
     *
     * <p>The producer is released even when writing the footer fails; in that case it is closed
     * quietly so the footer failure is the exception the caller sees.
     *
     * @throws IOException if the footer cannot be written or the producer fails to close
     */
    public void close() throws IOException {
        boolean footerWritten = false;
        try {
            this.codecUtilWrapper.writeFooter(this.postingOut);
            footerWritten = true;
        }
        finally {
            DocValuesProducer producer = this.docValuesProducer;
            this.docValuesProducer = null;
            if (producer != null) {
                if (footerWritten) {
                    producer.close();
                } else {
                    IOUtils.closeWhileHandlingException((Closeable)producer);
                }
            }
        }
    }

    /**
     * Closes the postings output and the doc-values producer while suppressing any exception
     * they raise. No footer is written.
     */
    public void closeWithException() {
        IOUtils.closeWhileHandlingException((Closeable)this.postingOut);
        DocValuesProducer producer = this.docValuesProducer;
        if (producer == null) {
            return;
        }
        IOUtils.closeWhileHandlingException((Closeable)producer);
        this.docValuesProducer = null;
    }

    /**
     * Writes {@code startFp} as a long to the postings output and then performs {@link #close()}.
     *
     * @param startFp file pointer value to record ahead of the footer
     * @throws IOException if writing or closing fails
     */
    public void close(long startFp) throws IOException {
        this.postingOut.writeLong(startFp);
        this.close();
    }

    /**
     * @param codecName        codec name written into the index header
     * @param version          format version written into the index header
     * @param codecUtilWrapper helper used to write the index header and footer
     */
    @Generated
    public ClusteredPostingTermsWriter(String codecName, int version, CodecUtilWrapper codecUtilWrapper) {
        this.codecName = codecName;
        this.version = version;
        this.codecUtilWrapper = codecUtilWrapper;
    }
}

