/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.FMClassificationModel;
import org.apache.spark.ml.classification.FMClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.MinMaxScaler;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Example of training a factorization machines (FM) classifier inside a Spark ML pipeline.
 *
 * <p>The pipeline indexes the {@code label} column, min-max scales the {@code features} vector,
 * fits an {@link FMClassifier}, and converts numeric predictions back to the original label
 * strings. Test-set accuracy and the learned model parameters are printed to standard output.
 */
public class JavaFMClassifierExample {
    /** LIBSVM input read when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_PATH = "data/mllib/sample_libsvm_data.txt";

    /**
     * Runs the example end to end and always stops the Spark session afterwards.
     *
     * @param args optional; when present, {@code args[0]} overrides the LIBSVM input path
     */
    public static void main(String[] args) {
        String dataPath = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_PATH;
        SparkSession spark = SparkSession.builder().appName("JavaFMClassifierExample").getOrCreate();
        try {
            Dataset<Row> data = spark.read().format("libsvm").load(dataPath);

            // The indexer and scaler are fit on the whole dataset (before the split) so every
            // label value and feature range seen in either split is known to them.
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol("label")
                    .setOutputCol("indexedLabel")
                    .fit(data);
            MinMaxScalerModel featureScaler = new MinMaxScaler()
                    .setInputCol("features")
                    .setOutputCol("scaledFeatures")
                    .fit(data);

            // 70% training / 30% test; no seed is set, so the split differs between runs.
            Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];

            FMClassifier fm = new FMClassifier()
                    .setLabelCol("indexedLabel")
                    .setFeaturesCol("scaledFeatures")
                    .setStepSize(0.001);

            // Maps the numeric "prediction" column back to the original label strings.
            IndexToString labelConverter = new IndexToString()
                    .setInputCol("prediction")
                    .setOutputCol("predictedLabel")
                    .setLabels(labelIndexer.labelsArray()[0]);

            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{labelIndexer, featureScaler, fm, labelConverter});
            PipelineModel model = pipeline.fit(trainingData);

            Dataset<Row> predictions = model.transform(testData);
            predictions.select("predictedLabel", "label", "features").show(5);

            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol("indexedLabel")
                    .setPredictionCol("prediction")
                    .setMetricName("accuracy");
            double accuracy = evaluator.evaluate(predictions);
            System.out.println("Test Accuracy = " + accuracy);

            // Stage index 2 is the fitted FMClassifier (indexer, scaler, fm, converter).
            FMClassificationModel fmModel = (FMClassificationModel) model.stages()[2];
            System.out.println("Factors: " + fmModel.factors());
            System.out.println("Linear: " + fmModel.linear());
            System.out.println("Intercept: " + fmModel.intercept());
        } finally {
            // Release the session even if loading, training, or evaluation fails.
            spark.stop();
        }
    }
}

