package org.jpmml.evaluator.neural_network;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasField;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvalidElementListException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.PMMLElements;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.XPathUtil;

/* loaded from: classes7.dex */
/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions under the
 * {@code float} and {@code double} math contexts. The model's neural inputs and neurons are
 * exposed through {@link #getEntityRegistry()}, a bidirectional map from string id to entity.
 *
 * <p>NOTE(review): this class was recovered by decompilation and the body of
 * {@link #evaluateRaw(ValueFactory, EvaluationContext)} could not be decompiled. As written, every
 * call to {@link #evaluate(ModelEvaluationContext)} that reaches scoring fails with
 * {@link UnsupportedOperationException}. TODO restore that method from the upstream sources.
 */
public class NeuralNetworkEvaluator extends ModelEvaluator<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /**
     * Per-model cache of the entity registry. Each {@link NeuralInput} is registered first, then
     * every {@link Neuron}, layer by layer, in list order.
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, Entity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, BiMap<String, Entity>>() {

        @Override
        public BiMap<String, Entity> load(NeuralNetwork neuralNetwork) {
            // NOTE(review): raw builder kept as in the original; EntityUtil.put's generic bounds
            // are not visible here.
            ImmutableBiMap.Builder builder = new ImmutableBiMap.Builder();

            // NOTE(review): presumably a 1-based counter EntityUtil uses to derive ids for
            // entities without an explicit one - confirm against EntityUtil.
            AtomicInteger index = new AtomicInteger(1);

            for (NeuralInput neuralInput : neuralNetwork.getNeuralInputs()) {
                builder = EntityUtil.put(neuralInput, index, builder);
            }
            for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                for (Neuron neuron : neuralLayer.getNeurons()) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }
            return builder.build();
        }
    });

    /** Lazily resolved registry for this evaluator's model (see {@link #entityCache}). */
    private transient BiMap<String, Entity> entityRegistry;

    /** Lazily built index from field name to the neural outputs whose expression references it. */
    private transient Map<FieldName, List<NeuralOutput>> neuralOutputMap;

    /**
     * Creates an evaluator for the {@link NeuralNetwork} model selected from {@code pmml}.
     *
     * @param pmml the PMML document containing the model
     */
    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, (NeuralNetwork) selectModel(pmml, NeuralNetwork.class));
    }

    /**
     * Creates an evaluator for the given model and validates its required structure.
     *
     * @param pmml the enclosing PMML document
     * @param neuralNetwork the model to evaluate
     * @throws MissingElementException if {@code NeuralInputs} is absent or empty, if there are no
     *     {@code NeuralLayer} elements, or if {@code NeuralOutputs} is absent or empty
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);

        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALINPUTS);
        }
        if (!neuralInputs.hasNeuralInputs()) {
            throw new MissingElementException(neuralInputs, PMMLElements.NEURALINPUTS_NEURALINPUTS);
        }

        if (!neuralNetwork.hasNeuralLayers()) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALLAYERS);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new MissingElementException(neuralNetwork, PMMLElements.NEURALNETWORK_NEURALOUTPUTS);
        }
        if (!neuralOutputs.hasNeuralOutputs()) {
            throw new MissingElementException(neuralOutputs, PMMLElements.NEURALOUTPUTS_NEURALOUTPUTS);
        }
    }

    /**
     * Scores the model as a classifier.
     *
     * <p>For each target field, the neural outputs mapped to it must each use a {@code NormDiscrete}
     * output expression; the output neuron's value becomes the probability of that category.
     * When {@link #evaluateRaw} returns {@code null}, the targets' default classifications are
     * returned instead.
     *
     * @return the per-target results; an empty map (never {@code null}) when there are no targets
     * @throws InvalidElementException if a target field has no neural outputs
     * @throws MissingAttributeException if {@code outputNeuron} or {@code NormDiscrete@value} is missing
     * @throws InvalidAttributeException if {@code outputNeuron} is unknown or has no computed value
     * @throws MisplacedElementException if an output expression is not {@code NormDiscrete}
     */
    private <V extends Number> Map<FieldName, ? extends Classification<V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        NeuralNetwork neuralNetwork = getModel();
        List<TargetField> targetFields = getTargetFields();

        ValueMap<String, V> neuronValues = evaluateRaw(valueFactory, evaluationContext);
        if (neuronValues == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateClassificationDefault(valueFactory, targetFields.get(0));
            }
            Map<FieldName, Classification<V>> defaultResults = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                defaultResults.putAll(TargetUtil.evaluateClassificationDefault(valueFactory, targetField));
            }
            return defaultResults;
        }

        Map<FieldName, List<NeuralOutput>> outputsByField = getNeuralOutputMap();
        BiMap<String, Entity> registry = getEntityRegistry();

        // Created eagerly so that an empty target list yields an empty map (as in the default
        // branch above) instead of null.
        Map<FieldName, Classification<V>> results = new LinkedHashMap<>();
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = outputsByField.get(targetField.getName());
            if (neuralOutputs == null) {
                throw new InvalidElementException(neuralNetwork);
            }

            NeuronProbabilityDistribution distribution = new NeuronProbabilityDistribution(new ValueMap(neuralOutputs.size() * 2), registry);
            for (NeuralOutput neuralOutput : neuralOutputs) {
                String outputNeuron = neuralOutput.getOutputNeuron();
                if (outputNeuron == null) {
                    throw new MissingAttributeException(neuralOutput, PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON);
                }
                Entity entity = registry.get(outputNeuron);
                if (entity == null) {
                    throw new InvalidAttributeException(neuralOutput, PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, outputNeuron);
                }
                Value<V> neuronValue = neuronValues.get(outputNeuron);
                if (neuronValue == null) {
                    throw new InvalidAttributeException(neuralOutput, PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, outputNeuron);
                }

                Expression expression = getOutputExpression(neuralOutput);
                if (!(expression instanceof NormDiscrete)) {
                    throw new MisplacedElementException(expression);
                }
                NormDiscrete normDiscrete = (NormDiscrete) expression;
                String category = normDiscrete.getValue();
                if (category == null) {
                    throw new MissingAttributeException(normDiscrete, PMMLAttributes.NORMDISCRETE_VALUE);
                }

                distribution.put(entity, category, neuronValue);
            }

            if (targetFields.size() == 1) {
                return TargetUtil.evaluateClassification(targetField, distribution);
            }
            results.putAll(TargetUtil.evaluateClassification(targetField, distribution));
        }
        return results;
    }

    /**
     * Presumably computes the value of every neural input and neuron, keyed by entity id.
     *
     * <p>NOTE(review): the decompiler (JADX) could not recover this method's body (444
     * instructions skipped). It therefore always throws {@link UnsupportedOperationException}.
     * Its call sites treat a {@code null} result as "use the targets' default results". They
     * otherwise look up output-neuron ids in the returned map, and a missing id is reported as
     * {@link InvalidAttributeException}. TODO restore the implementation from the upstream sources.
     */
    private <V extends Number> ValueMap<String, V> evaluateRaw(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        throw new UnsupportedOperationException("Method not decompiled: org.jpmml.evaluator.neural_network.NeuralNetworkEvaluator.evaluateRaw(org.jpmml.evaluator.ValueFactory, org.jpmml.evaluator.EvaluationContext):org.jpmml.evaluator.ValueMap");
    }

    /**
     * Scores the model as a regressor.
     *
     * <p>Each target field must map to exactly one neural output. Its expression is either a
     * {@code FieldRef} (value used as-is) or a {@code NormContinuous} (value denormalized). When
     * {@link #evaluateRaw} returns {@code null}, the targets' default values are returned instead.
     *
     * @return the per-target results; an empty map (never {@code null}) when there are no targets
     * @throws InvalidElementException if a target field has no neural outputs
     * @throws InvalidElementListException if a target field has more than one neural output
     * @throws MissingAttributeException if {@code outputNeuron} is missing
     * @throws InvalidAttributeException if {@code outputNeuron} has no computed value
     * @throws MisplacedElementException if the output expression is neither {@code FieldRef} nor
     *     {@code NormContinuous}
     */
    private <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        NeuralNetwork neuralNetwork = getModel();
        List<TargetField> targetFields = getTargetFields();

        ValueMap<String, V> neuronValues = evaluateRaw(valueFactory, evaluationContext);
        if (neuronValues == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateRegressionDefault(valueFactory, targetFields.get(0));
            }
            Map<FieldName, Object> defaultResults = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                defaultResults.putAll(TargetUtil.evaluateRegressionDefault(valueFactory, targetField));
            }
            return defaultResults;
        }

        Map<FieldName, List<NeuralOutput>> outputsByField = getNeuralOutputMap();

        // Created eagerly so that an empty target list yields an empty map (as in the default
        // branch above) instead of null.
        Map<FieldName, Object> results = new LinkedHashMap<>();
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = outputsByField.get(targetField.getName());
            if (neuralOutputs == null) {
                throw new InvalidElementException(neuralNetwork);
            }
            if (neuralOutputs.size() != 1) {
                throw new InvalidElementListException(neuralOutputs);
            }

            NeuralOutput neuralOutput = neuralOutputs.get(0);
            String outputNeuron = neuralOutput.getOutputNeuron();
            if (outputNeuron == null) {
                throw new MissingAttributeException(neuralOutput, PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON);
            }
            Value<V> neuronValue = neuronValues.get(outputNeuron);
            if (neuronValue == null) {
                throw new InvalidAttributeException(neuralOutput, PMMLAttributes.NEURALOUTPUT_OUTPUTNEURON, outputNeuron);
            }

            // denormalize(..)'s result is not used, so it presumably works in place; operate on a
            // copy to leave the raw neuron value untouched.
            Value<V> result = neuronValue.copy2();

            Expression expression = getOutputExpression(neuralOutput);
            if (expression instanceof NormContinuous) {
                NormalizationUtil.denormalize((NormContinuous) expression, result);
            } else if (!(expression instanceof FieldRef)) {
                throw new MisplacedElementException(expression);
            }

            if (targetFields.size() == 1) {
                return TargetUtil.evaluateRegression(targetField, result);
            }
            results.putAll(TargetUtil.evaluateRegression(targetField, result));
        }
        return results;
    }

    /** Returns the field-name index of neural outputs, building it on first use. */
    private Map<FieldName, List<NeuralOutput>> getNeuralOutputMap() {
        if (this.neuralOutputMap == null) {
            this.neuralOutputMap = parseNeuralOutputs();
        }
        return this.neuralOutputMap;
    }

    /**
     * Resolves the expression of a neural output's derived field.
     *
     * <p>A {@code FieldRef} is followed one level. If it names a {@link DataField}, the
     * {@code FieldRef} itself is returned. If it names a {@link DerivedField}, that field's
     * expression is returned. Any other expression is returned unchanged.
     *
     * @throws MissingElementException if the neural output has no {@code DerivedField}
     * @throws MissingAttributeException if a {@code FieldRef} has no {@code field}
     * @throws MissingFieldException if the referenced field cannot be resolved
     * @throws InvalidAttributeException if the referenced field is neither a data nor a derived field
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new MissingElementException(neuralOutput, PMMLElements.NEURALOUTPUT_DERIVEDFIELD);
        }

        Expression expression = ExpressionUtil.ensureExpression(derivedField);
        if (!(expression instanceof FieldRef)) {
            return expression;
        }

        FieldRef fieldRef = (FieldRef) expression;
        FieldName name = fieldRef.getField();
        if (name == null) {
            throw new MissingAttributeException(fieldRef, PMMLAttributes.FIELDREF_FIELD);
        }
        TypeDefinitionField referenced = resolveField(name);
        if (referenced == null) {
            throw new MissingFieldException(name, fieldRef);
        }

        if (referenced instanceof DataField) {
            return expression;
        }
        if (referenced instanceof DerivedField) {
            return ExpressionUtil.ensureExpression((DerivedField) referenced);
        }
        throw new InvalidAttributeException(fieldRef, PMMLAttributes.FIELDREF_FIELD, name);
    }

    /**
     * Groups the model's neural outputs by the field their output expression references.
     * Outputs for the same field keep their document order.
     *
     * @throws MisplacedElementException if an output expression does not reference a field
     * @throws MissingAttributeException if that expression has no {@code field} attribute
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Map<FieldName, List<NeuralOutput>> parseNeuralOutputs() {
        NeuralNetwork neuralNetwork = getModel();

        ArrayListMultimap<FieldName, NeuralOutput> outputsByField = ArrayListMultimap.create();
        for (NeuralOutput neuralOutput : neuralNetwork.getNeuralOutputs()) {
            Expression expression = getOutputExpression(neuralOutput);
            if (!(expression instanceof HasField)) {
                throw new MisplacedElementException(expression);
            }
            HasField hasField = (HasField) expression;
            FieldName field = hasField.getField();
            if (field == null) {
                throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(hasField.getClass()) + "@field"), expression);
            }
            outputsByField.put(field, neuralOutput);
        }

        // ListMultimap.asMap() is documented to contain List values, so this cast is safe.
        return (Map) outputsByField.asMap();
    }

    /**
     * Scores the model.
     *
     * @throws UnsupportedAttributeException if the math context is not {@code float}/{@code double},
     *     or the mining function is not recognized
     * @throws InvalidAttributeException if the mining function is not applicable to neural networks
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork neuralNetwork = ensureScorableModel();

        MathContext mathContext = neuralNetwork.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                break;
            default:
                throw new UnsupportedAttributeException(neuralNetwork, mathContext);
        }

        ValueFactory<?> valueFactory = getValueFactory();

        MiningFunction miningFunction = neuralNetwork.getMiningFunction();
        Map<FieldName, ?> results;
        switch (miningFunction) {
            case REGRESSION:
                results = evaluateRegression(valueFactory, modelEvaluationContext);
                break;
            case CLASSIFICATION:
                results = evaluateClassification(valueFactory, modelEvaluationContext);
                break;
            case ASSOCIATION_RULES:
            case SEQUENCES:
            case CLUSTERING:
            case TIME_SERIES:
            case MIXED:
                throw new InvalidAttributeException(neuralNetwork, miningFunction);
            default:
                throw new UnsupportedAttributeException(neuralNetwork, miningFunction);
        }

        return OutputUtil.evaluate(results, modelEvaluationContext);
    }

    /** Returns the id-to-entity registry for this model, resolving it from the cache on first use. */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = (BiMap) getValue(entityCache);
        }
        return this.entityRegistry;
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }
}
