/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.packageloader.action;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.Permission;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

/**
 * Static helpers for the ML package loader: resolving artefact locations inside a model repository,
 * opening {@code http}, {@code https} and {@code file} resources, parsing JSON vocabulary files, and
 * reading model data in fixed-size chunks.
 */
final class ModelLoaderUtils {

    /** File-name suffix of model metadata documents. */
    public static final String METADATA_FILE_EXTENSION = ".metadata.json";

    /** File-name suffix of model files. */
    public static final String MODEL_FILE_EXTENSION = ".pt";

    /** Maximum number of bytes {@link #parseVocabParts} reads from a vocabulary stream. */
    private static final ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(20L, ByteSizeUnit.MB);

    // Top-level keys of a vocabulary JSON document.
    private static final String VOCABULARY = "vocabulary";
    private static final String MERGES = "merges";
    private static final String SCORES = "scores";

    private ModelLoaderUtils() {}

    /**
     * Opens a stream over the resource at {@code uri}.
     *
     * @param uri location of the resource; its scheme must be {@code http}, {@code https} or {@code file}
     *            (case-insensitive)
     * @return an open stream; the caller is responsible for closing it
     * @throws IllegalArgumentException if the URI has no scheme or an unsupported one
     */
    static InputStream getInputStreamFromModelRepository(URI uri) {
        String scheme = uri.getScheme();
        if (scheme == null) {
            throw new IllegalArgumentException("unsupported scheme");
        }
        return switch (scheme.toLowerCase(Locale.ROOT)) {
            case "http", "https" -> getHttpOrHttpsInputStream(uri, null);
            case "file" -> getFileInputStream(uri);
            default -> throw new IllegalArgumentException("unsupported scheme");
        };
    }

    /**
     * Returns whether {@code uri} uses the {@code file} scheme (case-insensitive). A URI without a scheme
     * is not a file URI.
     */
    static boolean uriIsFile(URI uri) {
        String scheme = uri.getScheme();
        return scheme != null && "file".equals(scheme.toLowerCase(Locale.ROOT));
    }

    /**
     * Loads and parses a JSON vocabulary file.
     *
     * @param uri location of the vocabulary; its path must end in {@code .json}
     * @return the parsed vocabulary parts
     * @throws IllegalArgumentException if the URI has no path or the path does not end in {@code .json}
     * @throws ElasticsearchException wrapping any failure to open, read or parse the file
     */
    static VocabularyParts loadVocabulary(URI uri) {
        String path = uri.getPath();
        if (path == null || !path.endsWith(".json")) {
            throw new IllegalArgumentException("unknown format vocabulary file format");
        }
        // Opening the stream inside the resource clause guarantees it is closed on every path and that
        // open failures are wrapped the same way as parse failures.
        try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
            return parseVocabParts(vocabInputStream);
        } catch (Exception e) {
            throw new ElasticsearchException("Failed to load vocabulary file", e);
        }
    }

    /**
     * Parses a JSON object whose {@code vocabulary}, {@code merges} and {@code scores} keys each hold an array.
     * At most {@link #VOCABULARY_SIZE_LIMIT} bytes are read from the stream. A missing key yields an empty list.
     *
     * @param vocabInputStream JSON input; not closed by this method
     * @return the vocabulary and merges as strings and the scores as doubles
     * @throws IOException if the input cannot be read or parsed
     */
    static VocabularyParts parseVocabParts(InputStream vocabInputStream) throws IOException {
        Map<String, List<Object>> vocabParts;
        try (
            XContentParser sourceParser = XContentType.JSON.xContent()
                .createParser(XContentParserConfiguration.EMPTY, Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes()))
        ) {
            vocabParts = sourceParser.map(HashMap::new, XContentParser::list);
        }

        List<String> vocabulary = vocabParts.containsKey(VOCABULARY)
            ? vocabParts.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList())
            : List.of();
        List<String> merges = vocabParts.containsKey(MERGES)
            ? vocabParts.get(MERGES).stream().map(Object::toString).collect(Collectors.toList())
            : List.of();
        List<Double> scores = vocabParts.containsKey(SCORES)
            ? vocabParts.get(SCORES).stream().map(ModelLoaderUtils::toDouble).collect(Collectors.toList())
            : List.of();
        return new VocabularyParts(vocabulary, merges, scores);
    }

    /**
     * Widens a parsed JSON number to a {@code Double}. The JSON parser may return integral values (e.g. {@code 0})
     * as {@code Integer} or {@code Long} rather than {@code Double}, so the value is treated as a {@link Number}.
     * A {@code null} element stays {@code null}.
     */
    private static Double toDouble(Object value) {
        return value == null ? null : ((Number) value).doubleValue();
    }

    /**
     * Resolves {@code artefact} against the repository URI and verifies that the result stays inside the repository.
     *
     * @param repository base location of the repository; a trailing {@code /} is added when missing
     * @param artefact   relative reference of the artefact inside the repository
     * @return the normalized, validated location of the artefact
     * @throws URISyntaxException       if {@code repository} is not a valid URI
     * @throws IllegalArgumentException if the repository has no scheme, or if the resolved artefact changes the
     *                                  scheme or authority (host/port/user info) or escapes the repository path
     */
    static URI resolvePackageLocation(String repository, String artefact) throws URISyntaxException {
        URI baseUri = new URI(repository.endsWith("/") ? repository : repository + "/").normalize();
        if (Strings.isNullOrEmpty(baseUri.getScheme())) {
            throw new IllegalArgumentException("Repository must contain a scheme");
        }

        URI resolvedUri = baseUri.resolve(artefact).normalize();
        if (!baseUri.getScheme().equals(resolvedUri.getScheme())) {
            throw new IllegalArgumentException("Illegal schema change in package location");
        }
        // A network-path reference such as "//other-host/..." keeps the scheme and can satisfy the path check,
        // so the authority must be compared as well.
        if (!Objects.equals(baseUri.getAuthority(), resolvedUri.getAuthority())) {
            throw new IllegalArgumentException("Illegal authority change in package location");
        }
        String basePath = baseUri.getPath();
        String resolvedPath = resolvedUri.getPath();
        // Opaque URIs have no path; treat them as escaping the repository rather than failing with an NPE.
        if (basePath == null || resolvedPath == null || !resolvedPath.startsWith(basePath)) {
            throw new IllegalArgumentException("Illegal path in package location");
        }
        // Return exactly the URI that was validated above.
        return resolvedUri;
    }

    /**
     * Opens an HTTP(S) connection to {@code uri} inside a privileged block, optionally requesting a byte range.
     *
     * @param uri   location to download; must not carry user info (asserted)
     * @param range byte range to request via a {@code Range} header, or {@code null} for the whole resource
     * @return the response body stream for a 200 or 206 response
     * @throws IllegalStateException        on a 301/302/303 redirect or a 416 invalid-range response
     * @throws ResourceNotFoundException    on a 404 response
     * @throws ElasticsearchStatusException on any other response code
     * @throws UncheckedIOException         wrapping I/O failures while connecting
     */
    @SuppressForbidden(reason = "we need socket connection to download")
    private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) {
        assert uri.getUserInfo() == null : "URI's with credentials are not supported";

        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }

        PrivilegedAction<InputStream> privilegedHttpReader = () -> {
            try {
                HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
                if (range != null) {
                    conn.setRequestProperty("Range", range.bytesRange());
                }
                int responseCode = conn.getResponseCode();
                return switch (responseCode) {
                    case 200, 206 -> conn.getInputStream();
                    case 301, 302, 303 -> throw new IllegalStateException("redirects aren't supported yet");
                    case 404 -> throw new ResourceNotFoundException("{} not found", uri);
                    // Guard against a null range so an unexpected 416 is reported rather than turning into an NPE.
                    case 416 -> throw new IllegalStateException(
                        "Invalid request range [" + (range == null ? "none" : range.bytesRange()) + "]"
                    );
                    default -> throw new ElasticsearchStatusException(
                        "error during downloading {}. Got response code {}",
                        RestStatus.fromCode(responseCode),
                        uri,
                        responseCode
                    );
                };
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };
        return AccessController.doPrivileged(privilegedHttpReader);
    }

    /**
     * Opens the local file identified by a {@code file} URI inside a privileged block.
     *
     * @param uri {@code file} URI of the file to read
     * @return an open stream over the file; the caller is responsible for closing it
     * @throws ResourceNotFoundException if the file does not exist
     * @throws UncheckedIOException      wrapping I/O failures while opening the file
     */
    @SuppressForbidden(reason = "we need load model data from a file")
    static InputStream getFileInputStream(URI uri) {
        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }

        PrivilegedAction<InputStream> privilegedFileReader = () -> {
            File file = new File(uri);
            if (!file.exists()) {
                throw new ResourceNotFoundException("{} not found", uri);
            }
            try {
                return Files.newInputStream(file.toPath());
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };
        return AccessController.doPrivileged(privilegedFileReader);
    }

    /**
     * Splits a download of {@code sizeInBytes} bytes into contiguous, inclusive byte ranges.
     *
     * <p>Each range covers a whole number of chunks of {@code chunkSizeBytes}; only the file's final chunk may
     * be shorter. Chunks are shared out across streams as evenly as possible, with earlier streams taking any
     * remainder. The last stream's share is further split so that the final range always contains exactly one
     * chunk. As a result the returned list may hold one more range than the number of streams.
     *
     * @param sizeInBytes     total size of the file; must be positive
     * @param numberOfStreams requested number of streams; must be positive, and is capped at the number of chunks
     * @param chunkSizeBytes  size of each chunk; must be positive
     * @return the ranges in ascending offset order, together covering bytes {@code [0, sizeInBytes)}
     * @throws IllegalArgumentException if any argument is not positive
     * @throws ArithmeticException      if the number of chunks does not fit in an {@code int}
     */
    static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
        if (sizeInBytes <= 0) {
            throw new IllegalArgumentException("sizeInBytes must be positive but was [" + sizeInBytes + "]");
        }
        if (numberOfStreams <= 0) {
            throw new IllegalArgumentException("numberOfStreams must be positive but was [" + numberOfStreams + "]");
        }
        if (chunkSizeBytes <= 0) {
            throw new IllegalArgumentException("chunkSizeBytes must be positive but was [" + chunkSizeBytes + "]");
        }

        // Ceiling division; toIntExact fails loudly instead of silently wrapping for an absurd number of chunks.
        int numberOfChunks = Math.toIntExact((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);
        int streams = Math.min(numberOfStreams, numberOfChunks);

        List<RequestRange> ranges = new ArrayList<>();
        int baseChunksPerStream = numberOfChunks / streams;
        int remainder = numberOfChunks % streams;
        long startOffset = 0;
        int startChunkIndex = 0;

        // Every stream except the last. remainder < streams, so all the extra chunks land here.
        for (int i = 0; i < streams - 1; i++) {
            int numChunksInStream = i < remainder ? baseChunksPerStream + 1 : baseChunksPerStream;
            long rangeEnd = startOffset + numChunksInStream * chunkSizeBytes - 1; // ranges are inclusive
            ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
            startOffset = rangeEnd + 1;
            startChunkIndex += numChunksInStream;
        }

        // The last stream's share is split so that the final range is a single chunk.
        if (baseChunksPerStream > 1) {
            int numChunksExcludingFinal = baseChunksPerStream - 1;
            long rangeEnd = startOffset + numChunksExcludingFinal * chunkSizeBytes - 1;
            ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));
            startOffset = rangeEnd + 1;
            startChunkIndex += numChunksExcludingFinal;
        }

        // Only the file's final chunk remains. Its range ends on the file's last byte and may be shorter than a chunk.
        ranges.add(new RequestRange(startOffset, sizeInBytes - 1, startChunkIndex, 1));
        return ranges;
    }

    /**
     * An inclusive byte range of a remote file, together with the chunk (part) numbers it covers.
     *
     * @param rangeStart offset of the first byte in the range
     * @param rangeEnd   offset of the last byte in the range (inclusive)
     * @param startPart  index of the first chunk in the range
     * @param numParts   number of chunks in the range
     */
    record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) {
        /** Formats the range as an HTTP {@code Range} header value, e.g. {@code bytes=0-1023}. */
        public String bytesRange() {
            return "bytes=" + rangeStart + "-" + rangeEnd;
        }
    }

    /**
     * Parsed contents of a vocabulary file. Each list is empty when its key was absent from the document.
     */
    record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}

    /**
     * Reads an {@link InputStream} in fixed-size chunks while keeping a running SHA-256 digest and a running
     * count of all bytes read.
     */
    static class InputStreamChunker {
        private final InputStream inputStream;
        private final MessageDigest digestSha256 = MessageDigests.sha256();
        private final int chunkSize;
        // NOTE(review): an int counter overflows for inputs above Integer.MAX_VALUE bytes. Widening it would
        // change the return type of getTotalBytesRead().
        private int totalBytesRead = 0;

        InputStreamChunker(InputStream inputStream, int chunkSize) {
            this.inputStream = inputStream;
            this.chunkSize = chunkSize;
        }

        /**
         * Reads the next chunk, blocking until {@code chunkSize} bytes are available or the stream ends.
         * Returns an empty (or short) chunk once the stream is exhausted. Each call allocates a new buffer,
         * so previously returned chunks stay valid.
         */
        public BytesArray next() throws IOException {
            byte[] buf = new byte[chunkSize];
            int bytesRead = inputStream.readNBytes(buf, 0, chunkSize);
            digestSha256.update(buf, 0, bytesRead);
            totalBytesRead += bytesRead;
            return new BytesArray(buf, 0, bytesRead);
        }

        /**
         * Returns the lowercase hex SHA-256 of every byte read so far. {@link MessageDigest#digest()} resets
         * the digest, so call this once, after the last chunk has been read.
         */
        public String getSha256() {
            return MessageDigests.toHexString(digestSha256.digest());
        }

        /** Returns the total number of bytes read so far. */
        public int getTotalBytesRead() {
            return totalBytesRead;
        }
    }

    /**
     * Reads the chunks of a single {@link RequestRange} from a stream, numbering each chunk with its part
     * index. Parts run from {@code range.startPart()} (inclusive) to {@code range.startPart() + range.numParts()}
     * (exclusive).
     */
    static class HttpStreamChunker {
        private final InputStream inputStream;
        private final int chunkSize;
        private final AtomicLong totalBytesRead = new AtomicLong();
        private final AtomicInteger currentPart;
        private final int lastPartNumber;
        // Reused by every next() call; see the aliasing note on next().
        private final byte[] buf;

        /** Opens an HTTP(S) stream for {@code range} of {@code uri} and reads it in {@code chunkSize} pieces. */
        HttpStreamChunker(URI uri, RequestRange range, int chunkSize) {
            this(getHttpOrHttpsInputStream(uri, range), range, chunkSize);
        }

        /** Reads {@code inputStream}, which is expected to contain the bytes of {@code range}, in {@code chunkSize} pieces. */
        HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) {
            this.inputStream = inputStream;
            this.chunkSize = chunkSize;
            this.lastPartNumber = range.startPart() + range.numParts();
            this.currentPart = new AtomicInteger(range.startPart());
            this.buf = new byte[chunkSize];
        }

        /** Returns whether fewer parts than the range covers have been returned so far. */
        public boolean hasNext() {
            return currentPart.get() < lastPartNumber;
        }

        /**
         * Reads the next chunk, blocking until {@code chunkSize} bytes are available or the stream ends.
         *
         * <p>The returned bytes alias an internal buffer that the following call overwrites. Callers must finish
         * with (or copy) a chunk before calling {@code next()} again.
         *
         * <p>When nothing could be read, an empty chunk is returned and the part counter is not advanced. If the
         * stream ends before the last part, {@link #hasNext()} therefore keeps returning {@code true}.
         */
        public BytesAndPartIndex next() throws IOException {
            int bytesRead = inputStream.readNBytes(buf, 0, chunkSize);
            if (bytesRead == 0) {
                return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get());
            }
            totalBytesRead.addAndGet(bytesRead);
            return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement());
        }

        /** Returns the total number of bytes read so far. */
        public long getTotalBytesRead() {
            return totalBytesRead.get();
        }

        /** Returns the index of the part the next successful {@link #next()} call will return. */
        public int getCurrentPart() {
            return currentPart.get();
        }

        /** A chunk of bytes together with its part index. */
        record BytesAndPartIndex(BytesArray bytes, int partIndex) {}
    }
}

