/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.applications.nlp;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.interfaces.Extractable;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.StringCleaner;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.datatransformers.AbstractTransformer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCategoricalFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.wrappers.AbstractWrapper;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import com.datumbox.framework.core.utilities.text.extractors.AbstractTextExtractor;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;

public class TextClassifier
extends AbstractWrapper<ModelParameters, TrainingParameters> {
    public TextClassifier(String dbName, Configuration conf) {
        super(dbName, conf, ModelParameters.class, TrainingParameters.class);
    }

    public void fit(Map<Object, URI> datasets, TrainingParameters trainingParameters) {
        Dataframe trainingData = Dataframe.Builder.parseTextFiles(datasets, (Extractable)AbstractTextExtractor.newInstance(trainingParameters.getTextExtractorClass(), (AbstractTextExtractor.AbstractParameters)trainingParameters.getTextExtractorParameters()), (Configuration)this.kb().getConf());
        this.fit(trainingData, (AbstractTrainer.AbstractTrainingParameters)trainingParameters);
        trainingData.delete();
    }

    public void predict(Dataframe testDataset) {
        this.logger.info("predict()");
        this.kb().load();
        this.preprocessTestDataset(testDataset);
        this.modeler.predict(testDataset);
    }

    public Dataframe predict(URI datasetURI) {
        this.kb().load();
        HashMap<Object, URI> dataset = new HashMap<Object, URI>();
        dataset.put(null, datasetURI);
        TrainingParameters trainingParameters = (TrainingParameters)this.kb().getTrainingParameters();
        Dataframe testDataset = Dataframe.Builder.parseTextFiles(dataset, (Extractable)AbstractTextExtractor.newInstance(trainingParameters.getTextExtractorClass(), (AbstractTextExtractor.AbstractParameters)trainingParameters.getTextExtractorParameters()), (Configuration)this.kb().getConf());
        this.predict(testDataset);
        return testDataset;
    }

    public Record predict(String text) {
        this.kb().load();
        TrainingParameters trainingParameters = (TrainingParameters)this.kb().getTrainingParameters();
        Dataframe testDataset = new Dataframe(this.kb().getConf());
        testDataset.add(new Record(new AssociativeArray(AbstractTextExtractor.newInstance(trainingParameters.getTextExtractorClass(), (AbstractTextExtractor.AbstractParameters)trainingParameters.getTextExtractorParameters()).extract(StringCleaner.clear((String)text))), null));
        this.predict(testDataset);
        Record r = (Record)testDataset.iterator().next();
        testDataset.delete();
        return r;
    }

    public ValidationMetrics validate(Dataframe testDataset) {
        this.logger.info("validate()");
        this.kb().load();
        this.preprocessTestDataset(testDataset);
        AbstractModeler.AbstractValidationMetrics vm = this.modeler.validate(testDataset);
        return vm;
    }

    public ValidationMetrics validate(Map<Object, URI> datasets) {
        this.kb().load();
        TrainingParameters trainingParameters = (TrainingParameters)this.kb().getTrainingParameters();
        Dataframe testDataset = Dataframe.Builder.parseTextFiles(datasets, (Extractable)AbstractTextExtractor.newInstance(trainingParameters.getTextExtractorClass(), (AbstractTextExtractor.AbstractParameters)trainingParameters.getTextExtractorParameters()), (Configuration)this.kb().getConf());
        ValidationMetrics vm = this.validate(testDataset);
        testDataset.delete();
        return vm;
    }

    protected void _fit(Dataframe trainingDataset) {
        Class fsClass;
        boolean selectFeatures;
        boolean transformData;
        TrainingParameters trainingParameters = (TrainingParameters)this.kb().getTrainingParameters();
        Configuration conf = this.kb().getConf();
        Class dtClass = trainingParameters.getDataTransformerClass();
        boolean bl = transformData = dtClass != null;
        if (transformData) {
            this.dataTransformer = (AbstractTransformer)Trainable.newInstance((Class)dtClass, (String)this.dbName, (Configuration)conf);
            this.setParallelized((Trainable)this.dataTransformer);
            this.dataTransformer.fit_transform(trainingDataset, trainingParameters.getDataTransformerTrainingParameters());
        }
        boolean bl2 = selectFeatures = (fsClass = trainingParameters.getFeatureSelectorClass()) != null;
        if (selectFeatures) {
            this.featureSelector = (AbstractFeatureSelector)Trainable.newInstance((Class)fsClass, (String)this.dbName, (Configuration)conf);
            AbstractTrainer.AbstractTrainingParameters featureSelectorParameters = trainingParameters.getFeatureSelectorTrainingParameters();
            if (AbstractCategoricalFeatureSelector.AbstractTrainingParameters.class.isAssignableFrom(featureSelectorParameters.getClass())) {
                ((AbstractCategoricalFeatureSelector.AbstractTrainingParameters)featureSelectorParameters).setIgnoringNumericalFeatures(false);
            }
            this.setParallelized((Trainable)this.featureSelector);
            this.featureSelector.fit_transform(trainingDataset, trainingParameters.getFeatureSelectorTrainingParameters());
        }
        this.modeler = (AbstractModeler)Trainable.newInstance((Class)trainingParameters.getModelerClass(), (String)this.dbName, (Configuration)conf);
        this.setParallelized((Trainable)this.modeler);
        this.modeler.fit(trainingDataset, trainingParameters.getModelerTrainingParameters());
        if (transformData) {
            this.dataTransformer.denormalize(trainingDataset);
        }
    }

    private void preprocessTestDataset(Dataframe testDataset) {
        Class fsClass;
        boolean selectFeatures;
        boolean transformData;
        TrainingParameters trainingParameters = (TrainingParameters)this.kb().getTrainingParameters();
        Configuration conf = this.kb().getConf();
        Class dtClass = trainingParameters.getDataTransformerClass();
        boolean bl = transformData = dtClass != null;
        if (transformData) {
            if (this.dataTransformer == null) {
                this.dataTransformer = (AbstractTransformer)Trainable.newInstance((Class)dtClass, (String)this.dbName, (Configuration)conf);
            }
            this.setParallelized((Trainable)this.dataTransformer);
            this.dataTransformer.transform(testDataset);
        }
        boolean bl2 = selectFeatures = (fsClass = trainingParameters.getFeatureSelectorClass()) != null;
        if (selectFeatures) {
            if (this.featureSelector == null) {
                this.featureSelector = (AbstractFeatureSelector)Trainable.newInstance((Class)fsClass, (String)this.dbName, (Configuration)conf);
            }
            this.setParallelized((Trainable)this.featureSelector);
            this.featureSelector.transform(testDataset);
        }
        if (this.modeler == null) {
            this.modeler = (AbstractModeler)Trainable.newInstance((Class)trainingParameters.getModelerClass(), (String)this.dbName, (Configuration)conf);
        }
        this.setParallelized((Trainable)this.modeler);
    }

    public static class TrainingParameters
    extends AbstractWrapper.AbstractTrainingParameters<AbstractTransformer, AbstractFeatureSelector, AbstractModeler> {
        private static final long serialVersionUID = 1L;
        private Class<? extends AbstractTextExtractor> textExtractorClass;
        private AbstractTextExtractor.AbstractParameters textExtractorParameters;

        public Class<? extends AbstractTextExtractor> getTextExtractorClass() {
            return this.textExtractorClass;
        }

        public void setTextExtractorClass(Class<? extends AbstractTextExtractor> textExtractorClass) {
            this.textExtractorClass = textExtractorClass;
        }

        public AbstractTextExtractor.AbstractParameters getTextExtractorParameters() {
            return this.textExtractorParameters;
        }

        public void setTextExtractorParameters(AbstractTextExtractor.AbstractParameters textExtractorParameters) {
            this.textExtractorParameters = textExtractorParameters;
        }
    }

    public static class ModelParameters
    extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        protected ModelParameters(DatabaseConnector dbc) {
            super(dbc);
        }
    }
}

