/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.util.List;
import java.util.stream.IntStream;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * A {@link TokenizationResult} whose inference request is a JSON object with three entries:
 * the request id under {@code "request_id"}, whatever the inherited {@code writePaddedTokens}
 * emits for the field name {@code "tokens"}, and whatever the inherited
 * {@code writeAttentionMask} emits for the field name {@code "arg_1"}.
 */
public class MPNetTokenizationResult extends TokenizationResult {
    /** JSON field name carrying the request id. */
    static final String REQUEST_ID = "request_id";
    /** JSON field name handed to {@code writePaddedTokens}. */
    static final String TOKENS = "tokens";
    /** JSON field name handed to {@code writeAttentionMask}. */
    static final String ARG1 = "arg_1";

    /**
     * Passes all arguments straight through to the {@link TokenizationResult} constructor.
     *
     * @param vocab         forwarded to the superclass
     * @param tokenizations forwarded to the superclass
     * @param padTokenId    forwarded to the superclass
     */
    public MPNetTokenizationResult(List<String> vocab, List<TokenizationResult.Tokens> tokenizations, int padTokenId) {
        super(vocab, tokenizations, padTokenId);
    }

    /**
     * Serializes this result into a JSON request body.
     *
     * @param requestId value written under {@code "request_id"}
     * @param t         truncation setting; not read by this implementation
     * @return a request wrapping this result and the serialized JSON bytes
     * @throws IOException if building the JSON content fails
     */
    @Override
    public NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException {
        XContentBuilder json = XContentFactory.jsonBuilder();
        json.startObject().field(REQUEST_ID, requestId);
        writePaddedTokens(TOKENS, json);
        writeAttentionMask(ARG1, json);
        json.endObject();
        return new NlpTask.Request(this, BytesReference.bytes(json));
    }

    /**
     * Tokens builder that differs from {@link BertTokenizationResult.BertTokensBuilder} in how a
     * sequence pair is laid out: two separator tokens are placed between the sequences.
     */
    static class MPNetTokensBuilder extends BertTokenizationResult.BertTokensBuilder {
        MPNetTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) {
            super(withSpecialTokens, clsTokenId, sepTokenId);
        }

        /**
         * Appends a pair of sequences. With special tokens enabled the token ids are laid out as
         * {@code cls, sequence1, sep, sep, sequence2, sep}; otherwise as {@code sequence1, sequence2}.
         * Every special token is recorded as {@code -1} in the token map, and each entry of
         * {@code tokenMap2} is shifted by the last entry of {@code tokenMap1}.
         *
         * @param tokenId1s token ids of the first sequence
         * @param tokenMap1 token map of the first sequence; must not be empty
         * @param tokenId2s token ids of the second sequence
         * @param tokenMap2 token map of the second sequence
         * @return this builder
         * @throws IndexOutOfBoundsException if {@code tokenMap1} is empty
         */
        @Override
        public TokenizationResult.TokensBuilder addSequencePair(
            List<Integer> tokenId1s,
            List<Integer> tokenMap1,
            List<Integer> tokenId2s,
            List<Integer> tokenMap2
        ) {
            final boolean special = withSpecialTokens;

            // Leading cls token.
            if (special) {
                tokenIds.add(IntStream.of(clsTokenId));
                tokenMap.add(IntStream.of(-1));
            }

            // First sequence, copied through unchanged.
            tokenIds.add(tokenId1s.stream().mapToInt(Integer::intValue));
            tokenMap.add(tokenMap1.stream().mapToInt(Integer::intValue));
            final int shift = tokenMap1.get(tokenMap1.size() - 1);

            if (special) {
                // Two separators sit between the sequences, so the second sequence starts
                // after the cls token, the first sequence and both separators.
                tokenIds.add(IntStream.of(sepTokenId, sepTokenId));
                tokenMap.add(IntStream.of(-1, -1));
                seqPairOffset = tokenId1s.size() + 3;
            } else {
                seqPairOffset = tokenId1s.size();
            }

            // Second sequence; its map entries are offset by the last entry of the first map.
            tokenIds.add(tokenId2s.stream().mapToInt(Integer::intValue));
            tokenMap.add(tokenMap2.stream().mapToInt(entry -> entry + shift));

            // Trailing separator.
            if (special) {
                tokenIds.add(IntStream.of(sepTokenId));
                tokenMap.add(IntStream.of(-1));
            }
            return this;
        }
    }
}

