/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;

/**
 * Inference results produced by a classification model.
 *
 * <p>Besides the numeric prediction held by {@link SingleValueInferenceResults}, an instance carries:
 * <ul>
 *   <li>an optional class label;</li>
 *   <li>the top classes;</li>
 *   <li>an optional prediction probability and prediction score;</li>
 *   <li>the feature importance values.</li>
 * </ul>
 * The output field names for the prediction and for the top classes are configurable. They are used by
 * {@link #asMap()} and {@link #toXContent}.
 *
 * <p>The top classes and feature importance are always held as unmodifiable lists, whichever constructor
 * created the instance.
 */
public class ClassificationInferenceResults extends SingleValueInferenceResults {
    /** Name returned by {@link #getWriteableName()}. */
    public static final String NAME = "classification";
    /** Output key under which the prediction score is emitted by {@link #asMap()} and {@link #toXContent}. */
    public static final String PREDICTION_SCORE = "prediction_score";

    // Output key for the top classes list.
    private final String topNumClassesField;
    // Output key for the (transformed) predicted value.
    protected final String resultsField;
    // Human-readable predicted class; may be null, in which case the superclass string form is used.
    private final String classificationLabel;
    // Optional; omitted from output when null.
    private final Double predictionProbability;
    // Optional; omitted from output when null.
    private final Double predictionScore;
    // Never null; unmodifiable.
    private final List<TopClassEntry> topClasses;
    // Never null; unmodifiable; sorted by descending total importance when built from values.
    private final List<ClassificationFeatureImportance> featureImportance;
    // Controls how the predicted value is rendered in predictedValue(), asMap() and toXContent().
    private final PredictionFieldType predictionFieldType;

    /**
     * Creates results whose output settings come from the given inference config.
     *
     * @param value the numeric prediction
     * @param classificationLabel the predicted class label, or {@code null}
     * @param topClasses the top classes, or {@code null} for none
     * @param featureImportance the feature importance values, or {@code null} for none
     * @param config must be a {@link ClassificationConfig}. It supplies the top-classes results field, the
     *     results field, the prediction field type and the number of top feature importance values to keep.
     * @param predictionProbability the prediction probability, or {@code null}
     * @param predictionScore the prediction score, or {@code null}
     * @throws ClassCastException if {@code config} is not a {@link ClassificationConfig}
     */
    public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses, List<ClassificationFeatureImportance> featureImportance, InferenceConfig config, Double predictionProbability, Double predictionScore) {
        this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig) config, predictionProbability, predictionScore);
    }

    /**
     * Unpacks the output settings of a {@link ClassificationConfig} and delegates to the fully explicit constructor.
     */
    private ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses, List<ClassificationFeatureImportance> featureImportance, ClassificationConfig classificationConfig, Double predictionProbability, Double predictionScore) {
        this(value, classificationLabel, topClasses, featureImportance, classificationConfig.getTopClassesResultsField(), classificationConfig.getResultsField(), classificationConfig.getPredictionFieldType(), classificationConfig.getNumTopFeatureImportanceValues(), predictionProbability, predictionScore);
    }

    /**
     * Creates results with every output setting given explicitly.
     *
     * @param value the numeric prediction
     * @param classificationLabel the predicted class label, or {@code null}
     * @param topClasses the top classes, or {@code null} for none. A non-null list is wrapped in an unmodifiable
     *     view, not copied.
     * @param featureImportance the feature importance values, or {@code null} for none
     * @param topNumClassesField output key for the top classes
     * @param resultsField output key for the predicted value
     * @param predictionFieldType how the predicted value is rendered
     * @param numTopFeatureImportanceValues maximum number of feature importance entries to keep. The entries
     *     with the highest total importance are kept.
     * @param predictionProbability the prediction probability, or {@code null}
     * @param predictionScore the prediction score, or {@code null}
     * @throws IllegalArgumentException if {@code numTopFeatureImportanceValues} is negative and
     *     {@code featureImportance} is non-empty
     */
    public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses, List<ClassificationFeatureImportance> featureImportance, String topNumClassesField, String resultsField, PredictionFieldType predictionFieldType, int numTopFeatureImportanceValues, Double predictionProbability, Double predictionScore) {
        super(value);
        this.classificationLabel = classificationLabel;
        this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
        this.topNumClassesField = topNumClassesField;
        this.resultsField = resultsField;
        this.predictionFieldType = predictionFieldType;
        this.predictionProbability = predictionProbability;
        this.predictionScore = predictionScore;
        this.featureImportance = takeTopFeatureImportances(featureImportance, numTopFeatureImportanceValues);
    }

    /**
     * Returns the {@code numTopFeatures} entries with the highest total importance, in descending order.
     *
     * <p>Entries with equal total importance keep their relative input order, because the sort is stable on
     * an ordered stream.
     *
     * @param featureImportances the candidate entries; {@code null} or empty yields an empty list
     * @param numTopFeatures maximum number of entries to return
     * @return an unmodifiable list, never {@code null}
     * @throws IllegalArgumentException if {@code numTopFeatures} is negative and the input is non-empty
     */
    static List<ClassificationFeatureImportance> takeTopFeatureImportances(List<ClassificationFeatureImportance> featureImportances, int numTopFeatures) {
        if (featureImportances == null || featureImportances.isEmpty()) {
            return Collections.emptyList();
        }
        return featureImportances.stream()
            // Arguments swapped on purpose: sort by descending total importance.
            .sorted((l, r) -> Double.compare(r.getTotalImportance(), l.getTotalImportance()))
            .limit(numTopFeatures)
            .collect(Collectors.toUnmodifiableList());
    }

    /**
     * Deserializes results from a stream.
     *
     * <p>Fields are read in exactly the order that {@link #writeTo(StreamOutput)} writes them. The two methods
     * must be kept in sync.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public ClassificationInferenceResults(StreamInput in) throws IOException {
        super(in);
        // Wrapped so that deserialized instances are as immutable as those built from values.
        this.featureImportance = Collections.unmodifiableList(in.readList(ClassificationFeatureImportance::new));
        this.classificationLabel = in.readOptionalString();
        this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
        this.topNumClassesField = in.readString();
        this.resultsField = in.readString();
        this.predictionFieldType = in.readEnum(PredictionFieldType.class);
        this.predictionProbability = in.readOptionalDouble();
        this.predictionScore = in.readOptionalDouble();
    }

    /** @return the predicted class label, or {@code null} if none was set */
    public String getClassificationLabel() {
        return this.classificationLabel;
    }

    /** @return the top classes as an unmodifiable list, never {@code null} */
    public List<TopClassEntry> getTopClasses() {
        return this.topClasses;
    }

    /** @return how the predicted value is rendered */
    public PredictionFieldType getPredictionFieldType() {
        return this.predictionFieldType;
    }

    /** @return the feature importance values as an unmodifiable list, never {@code null} */
    public List<ClassificationFeatureImportance> getFeatureImportance() {
        return this.featureImportance;
    }

    /**
     * Serializes these results. The write order must match the read order of
     * {@link #ClassificationInferenceResults(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeList(this.featureImportance);
        out.writeOptionalString(this.classificationLabel);
        out.writeCollection(this.topClasses);
        out.writeString(this.topNumClassesField);
        out.writeString(this.resultsField);
        out.writeEnum(this.predictionFieldType);
        out.writeOptionalDouble(this.predictionProbability);
        out.writeOptionalDouble(this.predictionScore);
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (object == null || this.getClass() != object.getClass()) {
            return false;
        }
        ClassificationInferenceResults that = (ClassificationInferenceResults) object;
        return Objects.equals(this.value(), that.value())
            && Objects.equals(this.classificationLabel, that.classificationLabel)
            && Objects.equals(this.resultsField, that.resultsField)
            && Objects.equals(this.topNumClassesField, that.topNumClassesField)
            && Objects.equals(this.topClasses, that.topClasses)
            && Objects.equals(this.predictionFieldType, that.predictionFieldType)
            && Objects.equals(this.predictionProbability, that.predictionProbability)
            && Objects.equals(this.predictionScore, that.predictionScore)
            && Objects.equals(this.featureImportance, that.featureImportance);
    }

    @Override
    public int hashCode() {
        // Covers exactly the fields compared in equals().
        return Objects.hash(
            this.value(),
            this.classificationLabel,
            this.topClasses,
            this.resultsField,
            this.topNumClassesField,
            this.predictionProbability,
            this.predictionScore,
            this.featureImportance,
            this.predictionFieldType
        );
    }

    /** @return the class label if one is set, otherwise the superclass string form of the value */
    @Override
    public String valueAsString() {
        return this.classificationLabel == null ? super.valueAsString() : this.classificationLabel;
    }

    /** @return the predicted value transformed according to the {@link PredictionFieldType} */
    @Override
    public Object predictedValue() {
        return this.predictionFieldType.transformPredictedValue(this.value(), this.valueAsString());
    }

    /** @return the prediction probability, or {@code null} if none was set */
    public Double getPredictionProbability() {
        return this.predictionProbability;
    }

    /** @return the prediction score, or {@code null} if none was set */
    public Double getPredictionScore() {
        return this.predictionScore;
    }

    /** @return the output key used for the predicted value */
    @Override
    public String getResultsField() {
        return this.resultsField;
    }

    /**
     * Renders these results as an insertion-ordered map.
     *
     * <p>The map always contains the results field, mapped to the transformed predicted value. The following
     * optional keys are added only when present or non-empty:
     * <ul>
     *   <li>the top-classes field;</li>
     *   <li>{@code "prediction_probability"};</li>
     *   <li>{@value #PREDICTION_SCORE};</li>
     *   <li>{@code "feature_importance"}.</li>
     * </ul>
     */
    @Override
    public Map<String, Object> asMap() {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(this.resultsField, this.predictionFieldType.transformPredictedValue(this.value(), this.valueAsString()));
        if (!this.topClasses.isEmpty()) {
            map.put(this.topNumClassesField, this.topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
        }
        if (this.predictionProbability != null) {
            map.put("prediction_probability", this.predictionProbability);
        }
        if (this.predictionScore != null) {
            map.put(PREDICTION_SCORE, this.predictionScore);
        }
        if (!this.featureImportance.isEmpty()) {
            map.put("feature_importance", this.featureImportance.stream().map(ClassificationFeatureImportance::toMap).collect(Collectors.toList()));
        }
        return map;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Writes the same fields as {@link #asMap()}, under the same conditions, into {@code builder}.
     * This method only adds fields; it does not start or end an object itself.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(this.resultsField, this.predictionFieldType.transformPredictedValue(this.value(), this.valueAsString()));
        if (!this.topClasses.isEmpty()) {
            builder.field(this.topNumClassesField, this.topClasses);
        }
        if (this.predictionProbability != null) {
            builder.field("prediction_probability", this.predictionProbability);
        }
        if (this.predictionScore != null) {
            builder.field(PREDICTION_SCORE, this.predictionScore);
        }
        if (!this.featureImportance.isEmpty()) {
            builder.field("feature_importance", this.featureImportance);
        }
        return builder;
    }
}

