/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * {@link NlpTask.Processor} for the fill-mask task: validates that each input
 * contains exactly one {@code [MASK]} token and turns the raw model output into
 * a {@link FillMaskResults} holding the most probable replacement tokens.
 */
public class FillMaskProcessor implements NlpTask.Processor {
    /** Literal placeholder that marks the position to be predicted. */
    private static final String MASK_TOKEN = "[MASK]";
    /** Results field used when none is supplied. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
    /** Number of top classes requested when the config is not a {@link FillMaskConfig}. */
    private static final int DEFAULT_NUM_TOP_CLASSES = 5;

    private final NlpTask.RequestBuilder requestBuilder;

    /**
     * Creates a processor that builds requests with the tokenizer's request builder.
     *
     * @param tokenizer tokenizer supplying the request builder
     * @param config    currently unused; kept so existing callers keep compiling
     */
    FillMaskProcessor(NlpTokenizer tokenizer, FillMaskConfig config) {
        this.requestBuilder = tokenizer.requestBuilder();
    }

    /**
     * Checks that the request is non-empty and that every input contains exactly
     * one {@code [MASK]} token.
     *
     * @param inputs the raw input strings
     * @throws IllegalArgumentException if {@code inputs} is empty, or any input has
     *                                  no mask token or more than one mask token
     */
    @Override
    public void validateInputs(List<String> inputs) {
        if (inputs.isEmpty()) {
            throw new IllegalArgumentException("input request is empty");
        }
        for (String input : inputs) {
            int firstMask = input.indexOf(MASK_TOKEN);
            if (firstMask < 0) {
                throw new IllegalArgumentException("no [MASK] token could be found");
            }
            // Resume the search just past the first occurrence; any hit is a second mask.
            int secondMask = input.indexOf(MASK_TOKEN, firstMask + MASK_TOKEN.length());
            if (secondMask >= 0) {
                throw new IllegalArgumentException("only one [MASK] token should exist in the input");
            }
        }
    }

    /**
     * Returns the request builder obtained from the tokenizer at construction time.
     *
     * @param config ignored
     * @return the request builder
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return this.requestBuilder;
    }

    /**
     * Returns a result processor delegating to
     * {@link #processResult(TokenizationResult, PyTorchResult, int, String)}.
     * When {@code config} is a {@link FillMaskConfig} its number of top classes and
     * results field are used; otherwise the defaults (5 and {@code predicted_value}) apply.
     *
     * @param config the task configuration
     * @return the result processor
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        if (config instanceof FillMaskConfig) {
            FillMaskConfig fillMaskConfig = (FillMaskConfig) config;
            return (tokenization, result) -> FillMaskProcessor.processResult(
                tokenization,
                result,
                fillMaskConfig.getNumTopClasses(),
                fillMaskConfig.getResultsField()
            );
        }
        return (tokenization, result) -> FillMaskProcessor.processResult(
            tokenization,
            result,
            DEFAULT_NUM_TOP_CLASSES,
            DEFAULT_RESULTS_FIELD
        );
    }

    /**
     * Converts the model output at the mask position into fill-mask results.
     * Only the first tokenization and the first inference result row are examined.
     *
     * @param tokenization  the tokenized request
     * @param pyTorchResult the raw model output
     * @param numResults    number of top classes to report; {@code -1} requests all of
     *                      them and {@code 0} reports none (the single best prediction
     *                      is still computed to populate the result)
     * @param resultsField  name of the results field; {@code null} falls back to
     *                      {@code predicted_value}
     * @return a {@link FillMaskResults}, or a {@link WarningInferenceResults} when
     *         there are no tokens or the mask token is missing from the tokens
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchResult pyTorchResult,
        int numResults,
        String resultsField
    ) {
        if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
            return new WarningInferenceResults("No valid tokens for inference");
        }
        String[] tokens = tokenization.getTokenizations().get(0).getTokens();
        int maskTokenIndex = Arrays.asList(tokens).indexOf(MASK_TOKEN);
        if (maskTokenIndex < 0) {
            // Without this guard the -1 would be used as an array index below.
            // NOTE(review): presumably reachable when the input is truncated before the mask - confirm.
            return new WarningInferenceResults("No [MASK] token found in the tokenized input");
        }
        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(
            pyTorchResult.getInferenceResult()[0][maskTokenIndex]
        );
        // At least one entry is always requested so the top prediction can be reported.
        int topKCount = numResults == -1 ? Integer.MAX_VALUE : Math.max(numResults, 1);
        NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(topKCount, normalizedScores);

        List<TopClassEntry> results = new ArrayList<>(scoreAndIndices.length);
        if (numResults != 0) {
            for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) {
                String predictedToken = tokenization.getFromVocab(scoreAndIndex.index);
                results.add(new TopClassEntry(predictedToken, scoreAndIndex.score, scoreAndIndex.score));
            }
        }

        NlpHelpers.ScoreAndIndex best = scoreAndIndices[0];
        String predictedValue = tokenization.getFromVocab(best.index);
        String predictedSequence = tokenization.getTokenizations().get(0).getInput().replace(MASK_TOKEN, predictedValue);
        return new FillMaskResults(
            predictedValue,
            predictedSequence,
            results,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            Double.valueOf(best.score),
            tokenization.anyTruncated()
        );
    }
}

