package com.hankcs.hanlp.dependency.perceptron.transition.trainer;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.dependency.perceptron.accessories.Edge;
import com.hankcs.hanlp.dependency.perceptron.accessories.Evaluator;
import com.hankcs.hanlp.dependency.perceptron.accessories.Options;
import com.hankcs.hanlp.dependency.perceptron.accessories.Pair;
import com.hankcs.hanlp.dependency.perceptron.learning.AveragedPerceptron;
import com.hankcs.hanlp.dependency.perceptron.structures.IndexMaps;
import com.hankcs.hanlp.dependency.perceptron.structures.ParserModel;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.BeamElement;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Configuration;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Instance;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.State;
import com.hankcs.hanlp.dependency.perceptron.transition.features.FeatureExtractor;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.Action;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.ArcEager;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.KBeamArcEagerParser;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.LabeledAction;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.TransitionBasedParser;
import com.hankcs.hanlp.utility.MathUtility;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/* loaded from: classes.dex */
/**
 * Beam-search trainer for the arc-eager transition-based dependency parser.
 *
 * <p>Configurations are scored with the inherited {@code classifier} (an averaged perceptron).
 * The trainer replays an oracle transition sequence and a predicted transition sequence, counts
 * the features each one fires, and updates the weights on the difference. When
 * {@code updateMode} is {@code "max_violation"}, the most violating (predicted, oracle) pair is
 * used. Otherwise the top of the beam is compared against the best-scoring oracle.
 *
 * <p>Transitions are recorded in {@code Configuration.actionHistory} as integers:
 * {@code 0} = shift, {@code 1} = reduce, {@code 2} = unshift,
 * {@code 3 + label} = right-arc, and {@code 3 + |labels| + label} = left-arc,
 * where {@code |labels| = dependencyRelations.size()}.
 */
public class ArcEagerBeamTrainer extends TransitionBasedParser {
    /** {@code actionHistory} code of a shift transition. */
    private static final int ACTION_SHIFT = 0;
    /** {@code actionHistory} code of a reduce transition. */
    private static final int ACTION_REDUCE = 1;
    /** {@code actionHistory} code of an unshift transition (only replayed for predicted paths). */
    private static final int ACTION_UNSHIFT = 2;
    /** First {@code actionHistory} code of the right-arc block; left-arcs follow after all labels. */
    private static final int RIGHT_ARC_BASE = 3;

    Options options;
    private Random randGen;
    private String updateMode;

    /**
     * Creates a trainer.
     *
     * @param updateMode          weight-update strategy; {@code "max_violation"} enables
     *                            max-violation updates, any other value compares against the
     *                            top of the beam
     * @param averagedPerceptron  the model to train, handed to the superclass
     * @param options             training options (beam width, thread count, ...)
     * @param dependencyRelations dependency label ids, handed to the superclass
     * @param featureLength       handed to the superclass; presumably the number of feature
     *                            templates, later read back as {@code featureLength} — confirm
     * @param indexMaps           index maps, handed to the superclass
     */
    public ArcEagerBeamTrainer(String updateMode, AveragedPerceptron averagedPerceptron, Options options,
                               ArrayList<Integer> dependencyRelations, int featureLength, IndexMaps indexMaps) {
        super(averagedPerceptron, dependencyRelations, featureLength, indexMaps);
        this.updateMode = updateMode;
        this.options = options;
        this.randGen = new Random();
    }

    /**
     * Inserts a scored candidate into {@code beam}. If the beam then holds more than
     * {@code beamWidth} elements, the first element in the set's ordering is evicted
     * (presumably the lowest-scoring candidate; this depends on {@code BeamElement}'s ordering).
     */
    private void addToBeam(TreeSet<BeamElement> beam, int configIndex, float score, int action, int label,
                           int beamWidth) {
        beam.add(new BeamElement(score, configIndex, action, label));
        if (beam.size() > beamWidth) {
            beam.pollFirst();
        }
    }

    /**
     * Expands every configuration in {@code configurations} by each legal transition and keeps
     * the best {@code options.beamWidth} candidates in {@code candidates}.
     *
     * <p>Candidates use {@code BeamElement}'s own compact action ids, not the
     * {@code actionHistory} encoding: 0 = shift, 1 = reduce, 2 = right-arc, 3 = left-arc.
     * Each candidate's score is the parent score plus the transition score.
     */
    private void beamSortOneThread(ArrayList<Configuration> configurations, TreeSet<BeamElement> candidates) {
        int beamWidth = options.beamWidth;
        for (int index = 0; index < configurations.size(); index++) {
            Configuration configuration = configurations.get(index);
            State state = configuration.state;
            float base = configuration.score;
            boolean canShift = ArcEager.canDo(Action.Shift, state);
            boolean canReduce = ArcEager.canDo(Action.Reduce, state);
            boolean canRightArc = ArcEager.canDo(Action.RightArc, state);
            boolean canLeftArc = ArcEager.canDo(Action.LeftArc, state);
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, featureLength);

            if (canShift) {
                addToBeam(candidates, index, classifier.shiftScore(features, false) + base, 0, -1, beamWidth);
            }
            if (canReduce) {
                addToBeam(candidates, index, classifier.reduceScore(features, false) + base, 1, -1, beamWidth);
            }
            if (canRightArc) {
                float[] scores = classifier.rightArcScores(features, false);
                for (int label : dependencyRelations) {
                    addToBeam(candidates, index, scores[label] + base, 2, label, beamWidth);
                }
            }
            if (canLeftArc) {
                float[] scores = classifier.leftArcScores(features, false);
                for (int label : dependencyRelations) {
                    addToBeam(candidates, index, scores[label] + base, 3, label, beamWidth);
                }
            }
        }
    }

    /**
     * Decides whether the features fired by {@code action} in {@code configuration} should be
     * counted during a weight update.
     *
     * <p>For full trees ({@code isPartial == false}) every action counts. For partial trees:
     * shift counts only if the buffer head already has a head; reduce counts only if the stack
     * top has a head; arc actions count only if both the stack top and the buffer head have
     * heads; every other code (e.g. unshift) counts.
     */
    private static boolean isTrueFeature(boolean isPartial, Configuration configuration, int action) {
        if (!isPartial) {
            return true;
        }
        State state = configuration.state;
        if (action == ACTION_SHIFT) {
            return state.hasHead(state.bufferHead());
        }
        if (action == ACTION_REDUCE) {
            return state.hasHead(state.stackTop());
        }
        if (action >= RIGHT_ARC_BASE) {
            return state.hasHead(state.stackTop()) && state.hasHead(state.bufferHead());
        }
        return true;
    }

    /** Returns a copy of {@code source} advanced by a shift, with the action and score recorded. */
    private Configuration shifted(Configuration source, float score) {
        Configuration next = source.m14clone();
        ArcEager.shift(next.state);
        next.addAction(ACTION_SHIFT);
        next.addScore(score);
        return next;
    }

    /** Returns a copy of {@code source} advanced by a reduce, with the action and score recorded. */
    private Configuration reduced(Configuration source, float score) {
        Configuration next = source.m14clone();
        ArcEager.reduce(next.state);
        next.addAction(ACTION_REDUCE);
        next.addScore(score);
        return next;
    }

    /** Returns a copy of {@code source} advanced by a right-arc labelled {@code label}. */
    private Configuration rightArced(Configuration source, int label, float score) {
        Configuration next = source.m14clone();
        ArcEager.rightArc(next.state, label);
        next.addAction(RIGHT_ARC_BASE + label);
        next.addScore(score);
        return next;
    }

    /** Returns a copy of {@code source} advanced by a left-arc labelled {@code label}. */
    private Configuration leftArced(Configuration source, int label, float score) {
        Configuration next = source.m14clone();
        ArcEager.leftArc(next.state, label);
        next.addAction(RIGHT_ARC_BASE + dependencyRelations.size() + label);
        next.addScore(score);
        return next;
    }

    /**
     * Advances every non-terminal configuration in {@code oracles} by one gold transition and
     * adds the result to {@code successors`. Terminal configurations are added unchanged.
     *
     * <p>Rules, in order:
     * <ol>
     * <li>right-arc if the buffer head's gold head is the stack top;</li>
     * <li>left-arc if the stack top's gold head is the buffer head;</li>
     * <li>if the stack top is missing or has no head yet: reduce when only the root remains on
     *     the stack and the buffer is empty, otherwise shift;</li>
     * <li>reduce if the stack top has no gold children, or all of them are attached (the count
     *     equals its valence);</li>
     * <li>otherwise shift.</li>
     * </ol>
     *
     * @return the last successor created, or {@code null} if every configuration was terminal
     */
    private Configuration staticOracle(Instance instance, Collection<Configuration> oracles,
                                       Collection<Configuration> successors) {
        HashMap<Integer, Edge> goldHeads = instance.getGoldDependencies();
        HashMap<Integer, HashSet<Integer>> goldChildren = instance.getReversedDependencies();
        Configuration lastSuccessor = null;
        for (Configuration configuration : oracles) {
            State state = configuration.state;
            if (state.isTerminalState()) {
                successors.add(configuration);
                continue;
            }
            // Read these per configuration. They were previously shared across loop iterations,
            // so an empty stack or buffer could pick up an earlier configuration's index.
            int top = state.stackEmpty() ? -1 : state.stackTop();
            int first = state.bufferEmpty() ? -1 : state.bufferHead();
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, featureLength);
            Edge firstGold = first > 0 ? goldHeads.get(first) : null;
            Edge topGold = top > 0 ? goldHeads.get(top) : null;

            Configuration next;
            if (firstGold != null && firstGold.headIndex == top) {
                int label = firstGold.relationId;
                next = rightArced(configuration, label, classifier.rightArcScores(features, false)[label]);
            } else if (topGold != null && topGold.headIndex == first) {
                int label = topGold.relationId;
                next = leftArced(configuration, label, classifier.leftArcScores(features, false)[label]);
            } else if (top < 0 || !state.hasHead(top)) {
                if (state.bufferEmpty() && state.stackSize() == 1 && state.stackTop() == state.rootIndex) {
                    next = reduced(configuration, classifier.reduceScore(features, false));
                } else {
                    // NOTE(review): this is the only oracle score computed with the second
                    // argument set to true; every other call passes false. Confirm this is intended.
                    next = shifted(configuration, classifier.shiftScore(features, true));
                }
            } else if (!goldChildren.containsKey(top)
                    || goldChildren.get(top).size() == state.valence(top)) {
                next = reduced(configuration, classifier.reduceScore(features, false));
            } else {
                next = shifted(configuration, classifier.shiftScore(features, false));
            }
            successors.add(next);
            lastSuccessor = next;
        }
        return lastSuccessor;
    }

    /**
     * Runs one beam-search training step on {@code instance}.
     *
     * <p>NOTE(review): the decompiler could not recover this method's body, so it always throws
     * {@link UnsupportedOperationException}. As a result {@code train} throws as soon as it
     * processes a training instance. Restore the implementation from the upstream source.
     */
    private void trainOnOneSample(Instance instance, int partialTreeIter, int iteration, int dataCount,
                                  java.util.concurrent.CompletionService<ArrayList<BeamElement>> pool)
            throws InterruptedException, ExecutionException {
        throw new UnsupportedOperationException("Method not decompiled: com.hankcs.hanlp.dependency.perceptron.transition.trainer.ArcEagerBeamTrainer.trainOnOneSample(com.hankcs.hanlp.dependency.perceptron.transition.configuration.Instance, int, int, int, java.util.concurrent.CompletionService):void");
    }

    /**
     * Performs one perceptron update.
     *
     * <p>Chooses a (predicted, oracle) pair. In {@code "max_violation"} mode, the top of the beam
     * paired with {@code bestScoringOracle} replaces {@code maxViolationPair} when its score gap
     * exceeds {@code maxViolation}. Otherwise the top of the beam is compared with
     * {@code bestScoringOracle}.
     *
     * <p>Both action sequences are replayed from copies of {@code initialConfiguration}, and
     * features are counted per template. For every (action, feature) whose count differs between
     * the two sequences, the predicted count is subtracted from and the oracle count is added to
     * the corresponding weight.
     */
    private void updateWeights(Configuration initialConfiguration, float maxViolation, boolean isPartial,
                               Configuration bestScoringOracle,
                               Pair<Configuration, Configuration> maxViolationPair,
                               ArrayList<Configuration> beam) {
        Configuration predicted;
        Configuration oracle;
        if ("max_violation".equals(updateMode)) {
            Configuration top = beam.get(0);
            if (top.getScore(true) - bestScoringOracle.getScore(true) > maxViolation) {
                predicted = top;
                oracle = bestScoringOracle;
            } else {
                predicted = maxViolationPair.first;
                oracle = maxViolationPair.second;
            }
        } else {
            predicted = beam.get(0);
            oracle = bestScoringOracle;
        }

        ArrayList<HashMap<Pair<Integer, Object>, Float>> oracleCounts = newFeatureCounts();
        ArrayList<HashMap<Pair<Integer, Object>, Float>> predictedCounts = newFeatureCounts();
        collectFeatureCounts(initialConfiguration.m14clone(), oracle, isPartial, false, oracleCounts);
        collectFeatureCounts(initialConfiguration.m14clone(), predicted, isPartial, true, predictedCounts);

        for (int slot = 0; slot < featureLength; slot++) {
            HashMap<Pair<Integer, Object>, Float> predictedSlot = predictedCounts.get(slot);
            HashMap<Pair<Integer, Object>, Float> oracleSlot = oracleCounts.get(slot);
            applyCountDifference(slot, predictedSlot, oracleSlot, -1.0f);
            applyCountDifference(slot, oracleSlot, predictedSlot, 1.0f);
        }
    }

    /** Creates one empty (action, feature) -> count map per feature template. */
    private ArrayList<HashMap<Pair<Integer, Object>, Float>> newFeatureCounts() {
        ArrayList<HashMap<Pair<Integer, Object>, Float>> counts = new ArrayList<>(featureLength);
        for (int slot = 0; slot < featureLength; slot++) {
            counts.add(new HashMap<Pair<Integer, Object>, Float>());
        }
        return counts;
    }

    /**
     * Replays {@code path.actionHistory} on {@code replay}. For each action accepted by
     * {@link #isTrueFeature}, counts every (action code, feature value) pair into
     * {@code counts}, one map per template.
     *
     * <p>When {@code isPredicted} is true, unshift actions are replayed but never counted.
     * Otherwise the unshift code is ignored by the replay.
     */
    private void collectFeatureCounts(Configuration replay, Configuration path, boolean isPartial,
                                      boolean isPredicted,
                                      ArrayList<HashMap<Pair<Integer, Object>, Float>> counts) {
        for (int action : path.actionHistory) {
            boolean counted = isTrueFeature(isPartial, replay, action)
                    && !(isPredicted && action == ACTION_UNSHIFT);
            if (counted) {
                Object[] features = FeatureExtractor.extractAllParseFeatures(replay, featureLength);
                for (int slot = 0; slot < features.length; slot++) {
                    HashMap<Pair<Integer, Object>, Float> slotCounts = counts.get(slot);
                    Pair<Integer, Object> key = new Pair<Integer, Object>(action, features[slot]);
                    Float previous = slotCounts.get(key);
                    slotCounts.put(key, previous == null ? 1.0f : previous + 1.0f);
                }
            }
            applyTransition(replay.state, action, isPredicted);
        }
    }

    /** Applies the transition encoded by {@code action} (see class docs) to {@code state}. */
    private void applyTransition(State state, int action, boolean allowUnshift) {
        int leftArcBase = RIGHT_ARC_BASE + dependencyRelations.size();
        if (action == ACTION_SHIFT) {
            ArcEager.shift(state);
        } else if (action == ACTION_REDUCE) {
            ArcEager.reduce(state);
        } else if (action >= leftArcBase) {
            ArcEager.leftArc(state, action - leftArcBase);
        } else if (action >= RIGHT_ARC_BASE) {
            ArcEager.rightArc(state, action - RIGHT_ARC_BASE);
        } else if (allowUnshift && action == ACTION_UNSHIFT) {
            ArcEager.unShift(state);
        }
    }

    /**
     * For every non-null feature in {@code source} whose count is not identical in
     * {@code other}, changes its weight in template {@code slot} by {@code sign * count}.
     */
    private void applyCountDifference(int slot, HashMap<Pair<Integer, Object>, Float> source,
                                      HashMap<Pair<Integer, Object>, Float> other, float sign) {
        for (Pair<Integer, Object> key : source.keySet()) {
            // The decompiled code compared the Object against 0; the intent is a null check.
            if (key.second == null) {
                continue;
            }
            Float count = source.get(key);
            if (other.containsKey(key) && other.get(key).equals(count)) {
                continue;
            }
            LabeledAction labeled = new LabeledAction(key.first, dependencyRelations.size());
            classifier.changeWeight(labeled.action, slot, key.second, labeled.label, sign * count);
        }
    }

    /**
     * Expands every non-terminal configuration in {@code oracles} by every transition whose
     * {@code instance.actionCost} is zero, and adds each result to {@code successors}.
     * Terminal configurations are added unchanged. Arc transitions are only tried when
     * {@code ArcEager.canDo} allows them.
     *
     * @return the generated successor with the highest {@code getScore(true)} (the first one
     *         wins ties), or {@code null} if none was generated
     */
    private Configuration zeroCostDynamicOracle(Instance instance, Collection<Configuration> oracles,
                                                Collection<Configuration> successors) {
        float bestScore = Float.NEGATIVE_INFINITY;
        Configuration best = null;
        for (Configuration configuration : oracles) {
            State state = configuration.state;
            if (state.isTerminalState()) {
                successors.add(configuration);
                continue;
            }
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, featureLength);
            ArrayList<Configuration> zeroCost = new ArrayList<Configuration>();

            if (instance.actionCost(Action.Shift, -1, state) == 0) {
                zeroCost.add(shifted(configuration, classifier.shiftScore(features, false)));
            }
            if (ArcEager.canDo(Action.RightArc, state)) {
                float[] scores = classifier.rightArcScores(features, false);
                for (int label : dependencyRelations) {
                    if (instance.actionCost(Action.RightArc, label, state) == 0) {
                        zeroCost.add(rightArced(configuration, label, scores[label]));
                    }
                }
            }
            if (ArcEager.canDo(Action.LeftArc, state)) {
                float[] scores = classifier.leftArcScores(features, false);
                for (int label : dependencyRelations) {
                    if (instance.actionCost(Action.LeftArc, label, state) == 0) {
                        zeroCost.add(leftArced(configuration, label, scores[label]));
                    }
                }
            }
            if (instance.actionCost(Action.Reduce, -1, state) == 0) {
                zeroCost.add(reduced(configuration, classifier.reduceScore(features, false)));
            }

            for (Configuration successor : zeroCost) {
                successors.add(successor);
                if (successor.getScore(true) > bestScore) {
                    bestScore = successor.getScore(true);
                    best = successor;
                }
            }
        }
        return best;
    }

    /**
     * Trains the classifier for {@code maxIteration} passes over {@code trainData}.
     *
     * <p>After each pass the model is wrapped in a {@link ParserModel}. If {@code devPath} is
     * null or empty, the model is saved to {@code modelPath} after every pass. Otherwise the
     * dev file is parsed into {@code modelPath + ".__tmp__"} and evaluated, and the model is
     * saved only when the first evaluation score (printed as UAS) improves. Worker threads are
     * shut down even if training fails.
     *
     * @param trainData       training instances
     * @param devPath         CoNLL file used for evaluation after each pass; null/empty disables it
     * @param maxIteration    number of passes over {@code trainData}
     * @param modelPath       destination of the saved model
     * @param lowercase       forwarded to {@code parseConllFile}; presumably lower-cases input
     *                        words — confirm
     * @param punctuations    forwarded to {@code Evaluator.evaluate}; presumably punctuation
     *                        excluded from scoring — confirm
     * @param partialTreeIter forwarded to {@code trainOnOneSample}; its meaning is not visible
     *                        here (name follows upstream Yara parser — confirm)
     */
    public void train(ArrayList<Instance> trainData, String devPath, int maxIteration, String modelPath,
                      boolean lowercase, HashSet<String> punctuations, int partialTreeIter)
            throws IOException, ExecutionException, InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(options.numOfThreads);
        ExecutorCompletionService<ArrayList<BeamElement>> pool =
                new ExecutorCompletionService<ArrayList<BeamElement>>(executor);
        double bestUas = -1.0d;
        try {
            for (int iteration = 1; iteration <= maxIteration; iteration++) {
                long startMillis = System.currentTimeMillis();
                int total = trainData.size();
                // Refresh the progress line at most about 10,000 times per pass.
                int progressStep = (int) Math.ceil(total / 10000.0f);
                int processed = 0;
                for (Instance instance : trainData) {
                    processed++;
                    if (processed % progressStep == 0 || processed == total) {
                        System.out.printf("\r迭代 " + iteration + "/" + maxIteration + " %.2f%% ",
                                MathUtility.percentage(processed, total));
                    }
                    trainOnOneSample(instance, partialTreeIter, iteration, processed, pool);
                    classifier.incrementIteration();
                }
                long elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000;
                System.out.print(" 耗时 " + elapsedSeconds + " 秒。");

                ParserModel model = new ParserModel(classifier, maps, dependencyRelations, options);
                if (devPath == null || devPath.isEmpty()) {
                    model.saveModel(modelPath);
                    System.out.println();
                } else {
                    double[] score = evaluateOnDev(model, devPath, modelPath, lowercase, punctuations);
                    System.out.printf("UAS=%.2f LAS=%.2f", score[0], score[1]);
                    if (score[0] > bestUas) {
                        bestUas = score[0];
                        System.out.println(" 最高分！保存中...");
                        model.saveModel(modelPath);
                    } else {
                        System.out.println();
                    }
                }
            }
        } finally {
            shutdown(executor);
        }
    }

    /**
     * Parses {@code devPath} with {@code model} into a temporary file next to {@code modelPath},
     * evaluates it, and returns the evaluator's scores. The temporary file is deleted and the
     * parser's threads are shut down even if parsing or evaluation fails.
     */
    private double[] evaluateOnDev(ParserModel model, String devPath, String modelPath, boolean lowercase,
                                   HashSet<String> punctuations)
            throws IOException, ExecutionException, InterruptedException {
        KBeamArcEagerParser parser = new KBeamArcEagerParser(new AveragedPerceptron(model), dependencyRelations,
                featureLength, maps, options.numOfThreads, options);
        String outputPath = modelPath + ".__tmp__";
        try {
            parser.parseConllFile(devPath, outputPath, options.rootFirst, options.beamWidth, true, lowercase,
                    options.numOfThreads, false, "");
            return Evaluator.evaluate(devPath, outputPath, punctuations);
        } finally {
            IOUtil.deleteFile(outputPath);
            parser.shutDownLiveThreads();
        }
    }

    /**
     * Interrupts the workers and blocks until the pool terminates. The interrupt is re-sent
     * every second, and the thread waits in between rather than busy-spinning.
     */
    private static void shutdown(ExecutorService executor) throws InterruptedException {
        executor.shutdownNow();
        while (!executor.awaitTermination(1, TimeUnit.SECONDS)) {
            executor.shutdownNow();
        }
    }
}
