/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.mining.word2vec;

import com.hankcs.hanlp.mining.word2vec.Config;
import com.hankcs.hanlp.mining.word2vec.NeuralNetworkType;
import com.hankcs.hanlp.mining.word2vec.Preconditions;
import com.hankcs.hanlp.mining.word2vec.TrainingCallback;
import com.hankcs.hanlp.mining.word2vec.Utility;
import com.hankcs.hanlp.mining.word2vec.Word2VecTraining;
import com.hankcs.hanlp.mining.word2vec.WordVectorModel;
import com.hankcs.hanlp.utility.Predefine;
import com.hankcs.hanlp.utility.TextUtility;
import java.io.IOException;

/**
 * Fluent front end for training word2vec embeddings.
 * <p>
 * Configure the trainer with the chained setters, then call {@link #train(String, String)}.
 * Every option has a default (see the field initialisers below). If the initial learning rate is
 * not set explicitly, it falls back to
 * {@link NeuralNetworkType#getDefaultInitialLearningRate()} of the selected network type.
 * <p>
 * NOTE(review): instances hold mutable configuration and perform no synchronization; treat a
 * trainer as confined to a single configuring thread.
 */
public class Word2VecTrainer {
    /** Vector dimensionality passed to {@code Config.setLayer1Size} (default 200). */
    private int layerSize = 200;
    /** Context window size passed to {@code Config.setWindow} (default 5). */
    private int windowSize = 5;
    /** Number of training threads (default: number of available processors). */
    private int numThreads = Runtime.getRuntime().availableProcessors();
    /** Number of negative samples passed to {@code Config.setNegative} (default 25). */
    private int negativeSamples = 25;
    /** Whether hierarchical softmax is enabled (default {@code false}). */
    private boolean useHierarchicalSoftmax;
    /** Minimum word frequency passed to {@code Config.setMinCount} (default 5). */
    private int minFrequency = 5;
    /** Explicit initial learning rate, or {@code null} to use the network type's default. */
    private Float initialLearningRate;
    /** Down-sampling rate passed to {@code Config.setSample} (default 1e-4). */
    private float downSampleRate = 1.0E-4f;
    /** Number of training iterations passed to {@code Config.setIter} (default 15). */
    private int iterations = 15;
    /** Network architecture (default {@link NeuralNetworkType#CBOW}). */
    private NeuralNetworkType type = NeuralNetworkType.CBOW;
    /** Optional training callback forwarded to the config; {@code null} by default. */
    private TrainingCallback callback;

    /**
     * Sets the callback that is forwarded to the training configuration.
     * <p>
     * Kept {@code void} (unlike the other setters) for binary compatibility with existing callers.
     *
     * @param callback the callback, or {@code null} for none (the default)
     */
    public void setCallback(TrainingCallback callback) {
        this.callback = callback;
    }

    /**
     * Sets the dimensionality of the word vectors.
     *
     * @param layerSize vector size; must be positive
     * @return this trainer
     * @throws RuntimeException from {@code Preconditions.checkArgument} if {@code layerSize <= 0}
     */
    public Word2VecTrainer setLayerSize(int layerSize) {
        Preconditions.checkArgument(layerSize > 0, "Value must be positive");
        this.layerSize = layerSize;
        return this;
    }

    /**
     * Sets the context window size.
     *
     * @param windowSize window size; must be positive
     * @return this trainer
     */
    public Word2VecTrainer setWindowSize(int windowSize) {
        Preconditions.checkArgument(windowSize > 0, "Value must be positive");
        this.windowSize = windowSize;
        return this;
    }

    /**
     * Sets the number of threads used for training.
     *
     * @param numThreads thread count; must be positive
     * @return this trainer
     */
    public Word2VecTrainer useNumThreads(int numThreads) {
        Preconditions.checkArgument(numThreads > 0, "Value must be positive");
        this.numThreads = numThreads;
        return this;
    }

    /**
     * Selects the network architecture. Only {@link NeuralNetworkType#CBOW} enables
     * continuous-bag-of-words in the generated config; any other type disables it.
     *
     * @param type network type; must not be {@code null}
     * @return this trainer
     */
    public Word2VecTrainer type(NeuralNetworkType type) {
        this.type = Preconditions.checkNotNull(type);
        return this;
    }

    /**
     * Enables hierarchical softmax. Equivalent to {@code useHierarchicalSoftmax(true)}.
     *
     * @return this trainer
     */
    public Word2VecTrainer useHierarchicalSoftmax() {
        return useHierarchicalSoftmax(true);
    }

    /**
     * Enables or disables hierarchical softmax.
     *
     * @param enabled {@code true} to enable, {@code false} to disable (the default)
     * @return this trainer
     */
    public Word2VecTrainer useHierarchicalSoftmax(boolean enabled) {
        this.useHierarchicalSoftmax = enabled;
        return this;
    }

    /**
     * Sets the number of negative samples.
     *
     * @param negativeSamples sample count; must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer useNegativeSamples(int negativeSamples) {
        Preconditions.checkArgument(negativeSamples >= 0, "Value must be non-negative");
        this.negativeSamples = negativeSamples;
        return this;
    }

    /**
     * Sets the minimum frequency forwarded to the config's min-count setting
     * (presumably rarer words are excluded from the vocabulary).
     *
     * @param minFrequency minimum count; must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer setMinVocabFrequency(int minFrequency) {
        Preconditions.checkArgument(minFrequency >= 0, "Value must be non-negative");
        this.minFrequency = minFrequency;
        return this;
    }

    /**
     * Overrides the initial learning rate (otherwise the network type's default is used).
     *
     * @param initialLearningRate learning rate; must be non-negative
     * @return this trainer
     */
    public Word2VecTrainer setInitialLearningRate(float initialLearningRate) {
        Preconditions.checkArgument(initialLearningRate >= 0.0f, "Value must be non-negative");
        this.initialLearningRate = Float.valueOf(initialLearningRate);
        return this;
    }

    /**
     * Sets the down-sampling rate forwarded to the config's sample setting.
     *
     * @param downSampleRate rate; must be non-negative (NaN is rejected as well)
     * @return this trainer
     */
    public Word2VecTrainer setDownSamplingRate(float downSampleRate) {
        Preconditions.checkArgument(downSampleRate >= 0.0f, "Value must be non-negative");
        this.downSampleRate = downSampleRate;
        return this;
    }

    /**
     * Sets the number of training iterations.
     *
     * @param iterations iteration count; must be positive
     * @return this trainer
     */
    public Word2VecTrainer setNumIterations(int iterations) {
        Preconditions.checkArgument(iterations > 0, "Value must be positive");
        this.iterations = iterations;
        return this;
    }

    /**
     * Trains a model on {@code trainFileName}, writes it to {@code modelFileName}, and loads it
     * back.
     * <p>
     * Prints the elapsed wall-clock training time to standard output.
     *
     * @param trainFileName path of the training corpus
     * @param modelFileName path the trained model is written to and then loaded from
     * @return the loaded model, or {@code null} if an {@link IOException} occurs while training
     *     or loading; the exception is logged as a warning, not rethrown
     */
    public WordVectorModel train(String trainFileName, String modelFileName) {
        // Build the complete config (callback included) before handing it to the trainer, so
        // nothing is mutated after Word2VecTraining has received it.
        Config settings = buildConfig(trainFileName, modelFileName);
        Word2VecTraining model = new Word2VecTraining(settings);
        long timeStart = System.currentTimeMillis();
        try {
            model.trainModel();
            System.out.println();
            System.out.printf("\u8bad\u7ec3\u7ed3\u675f\uff0c\u4e00\u5171\u8017\u65f6\uff1a%s\n", Utility.humanTime(System.currentTimeMillis() - timeStart));
            return new WordVectorModel(modelFileName);
        }
        catch (IOException e) {
            // Best-effort contract: log and signal failure with null rather than propagate.
            Predefine.logger.warning("\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u53d1\u751fIO\u5f02\u5e38\n" + TextUtility.exceptionToString(e));
            return null;
        }
    }

    /** Translates this trainer's options into a fully populated {@link Config}. */
    private Config buildConfig(String trainFileName, String modelFileName) {
        Config settings = new Config();
        settings.setInputFile(trainFileName);
        settings.setLayer1Size(this.layerSize);
        settings.setUseContinuousBagOfWords(this.type == NeuralNetworkType.CBOW);
        settings.setUseHierarchicalSoftmax(this.useHierarchicalSoftmax);
        settings.setNegative(this.negativeSamples);
        settings.setNumThreads(this.numThreads);
        // A null override means "use the architecture's default learning rate".
        settings.setAlpha(this.initialLearningRate == null ? this.type.getDefaultInitialLearningRate() : this.initialLearningRate.floatValue());
        settings.setSample(this.downSampleRate);
        settings.setWindow(this.windowSize);
        settings.setIter(this.iterations);
        settings.setMinCount(this.minFrequency);
        settings.setOutputFile(modelFileName);
        settings.setCallback(this.callback);
        return settings;
    }
}

