Java source code examples: org.apache.spark.ml.tuning.CrossValidatorModel

Example 1
/**
 * Runs k-fold cross-validation over the pipeline built by createPipeline and keeps the best model.
 *
 * @param trainData        the labeled training DataFrame
 * @param trainingSettings configuration for the run (classification method, number of folds, etc.)
 * @return the fitted CrossValidatorModel holding the best model found
 */
public CrossValidatorModel crossValidate(DataFrame trainData, TrainingSettings trainingSettings) {

    //First create the pipeline and the ParamGrid
    createPipeline(trainData, trainingSettings);

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    CrossValidator cv = new CrossValidator()
            .setEstimator(pipeline)
            .setEvaluator(evaluator)
            .setEstimatorParamMaps(paramGrid)
            .setNumFolds(trainingSettings.getNumFolds());

    if (classificationMethod.equals(TrainingSettings.ClassificationMethod.LOG_REG)) {
        // Re-balance the classes: weight positive rows by the fraction of negatives
        // and negative rows by the fraction of positives.
        long numPositive = trainData.filter(col("label").equalTo("1.0")).count();
        long datasetSize = trainData.count();
        double balancingRatio = (double) (datasetSize - numPositive) / datasetSize;

        trainData = trainData
                .withColumn("classWeightCol",
                        when(col("label").equalTo("1.0"), balancingRatio)
                                .otherwise(1.0 - balancingRatio));
    }
    // Run cross-validation, and choose the best set of parameters.
    bestModel = cv.fit(trainData);
    System.out.println("IS LARGER BETTER ?"+bestModel.getEvaluator().isLargerBetter());
    return bestModel;
}
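
The class weights computed above only matter if the LogisticRegression stage created in createPipeline actually reads the "classWeightCol" column. That wiring is not shown in this example; the snippet below is only a minimal sketch of what it could look like (the weight, label, and feature column names are assumptions):

// Hypothetical wiring inside createPipeline(...): point the estimator at the weight column.
LogisticRegression lr = new LogisticRegression()
        .setLabelCol("label")
        .setFeaturesCol("features")
        .setWeightCol("classWeightCol");   // per-row instance weights used during fitting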
 
Example 2
public static void debugOutputModel(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output+"debug_"+trainingSettings.getClassificationMethod()+".txt");
    fs.delete(statsPath, true);

    FSDataOutputStream fsdos = fs.create(statsPath);
    PipelineModel pipelineModel = (PipelineModel) model.bestModel();
    switch (trainingSettings.getClassificationMethod()) {
        case RANDOM_FOREST:
            for (int i = 0; i < pipelineModel.stages().length; i++) {
                if (pipelineModel.stages()[i] instanceof RandomForestClassificationModel) {
                    RandomForestClassificationModel rfModel = (RandomForestClassificationModel) (pipelineModel.stages()[i]);
                    IOUtils.write(rfModel.toDebugString(), fsdos);
                    logger.info(rfModel.toDebugString());
                }
            }
            break;
        case LOG_REG:
            for (int i = 0; i < pipelineModel.stages().length; i++) {
                if (pipelineModel.stages()[i] instanceof LogisticRegressionModel) {
                    LogisticRegressionModel lgModel = (LogisticRegressionModel) (pipelineModel.stages()[i]);
                    IOUtils.write(lgModel.toString(), fsdos);
                    logger.info(lgModel.toString());
                }
            }
            break;
    }
    fsdos.flush();
    IOUtils.closeQuietly(fsdos);
}
 
Example 3
    @Override
    protected int run() throws Exception {

        SparkConf sparkConf = new SparkConf()
                .setAppName("EntitySalienceTrainingSparkRunner")
                .set("spark.hadoop.validateOutputSpecs", "false")
                .set("spark.yarn.executor.memoryOverhead", "3072")
                .set("spark.rdd.compress", "true")
                .set("spark.core.connection.ack.wait.timeout", "600")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                //.set("spark.kryo.registrationRequired", "true")
                .registerKryoClasses(new Class[] {SCAS.class, LabeledPoint.class, SparseVector.class, int[].class, double[].class,
                        InternalRow[].class, GenericInternalRow.class, Object[].class, GenericArrayData.class,
                        VectorIndexer.class})
                ; //.setMaster("local[4]"); uncomment only when running locally, not on the cluster.

        TrainingSettings trainingSettings = new TrainingSettings();

        if(folds != null) {
            trainingSettings.setNumFolds(folds);
        }
        if(method != null) {
            trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.valueOf(method));
        }
        if(defaultConf != null) {
            trainingSettings.setAidaDefaultConf(defaultConf);
        }

        if(scalingFactor != null) {
            trainingSettings.setPositiveInstanceScalingFactor(scalingFactor);
        }

        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        int totalCores = Integer.parseInt(sc.getConf().get("spark.executor.instances"))
                * Integer.parseInt(sc.getConf().get("spark.executor.cores"));

//        int totalCores = 4;
////        trainingSettings.setFeatureExtractor(TrainingSettings.FeatureExtractor.ANNOTATE_AND_ENTITY_SALIENCE);
////        trainingSettings.setAidaDefaultConf("db");
//        //trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.LOG_REG);
//        trainingSettings.setPositiveInstanceScalingFactor(1);

        //Add the cache files to each node only if annotation is required.
        //The input documents could already be annotated, and in this case no caches are needed.
        if(trainingSettings.getFeatureExtractor().equals(TrainingSettings.FeatureExtractor.ANNOTATE_AND_ENTITY_SALIENCE)) {
            sc.addFile(trainingSettings.getBigramCountCache());
            sc.addFile(trainingSettings.getKeywordCountCache());
            sc.addFile(trainingSettings.getWordContractionsCache());
            sc.addFile(trainingSettings.getWordExpansionsCache());
            if (trainingSettings.getAidaDefaultConf().equals("db")) {
                sc.addFile(trainingSettings.getDatabaseAida());
            } else {
                sc.addFile(trainingSettings.getCassandraConfig());
            }
        }

        SQLContext sqlContext = new SQLContext(sc);


        FileSystem fs = FileSystem.get(new Configuration());

        int partitionNumber = 3 * totalCores;
        if(partitions != null) {
            partitionNumber = partitions;
        }

        //Read training documents serialized as SCAS
        JavaRDD<SCAS> documents = sc.sequenceFile(input, Text.class, SCAS.class, partitionNumber).values();

        //Instantiate a training Spark runner
        TrainingSparkRunner trainingSparkRunner = new TrainingSparkRunner();

        //Train a model
        CrossValidatorModel model = trainingSparkRunner.crossValidate(sc, sqlContext, documents, trainingSettings);


        //Create the model path
        String modelPath = output+"/"+sc.getConf().getAppId()+"/model_"+trainingSettings.getClassificationMethod();

        //Delete the old model if there is one
        fs.delete(new Path(modelPath), true);

        //Save the new model
        List<Model> models = new ArrayList<>();
        models.add(model.bestModel());
        sc.parallelize(models, 1).saveAsObjectFile(modelPath);

        //Save the model stats
        SparkClassificationModel.saveStats(model, trainingSettings, output+"/"+sc.getConf().getAppId()+"/");


        return 0;
    }
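
Example 3 persists only the best PipelineModel, via saveAsObjectFile on a single-element RDD. Loading it back in a later job could look roughly like the sketch below; the cast to PipelineModel follows Examples 2 and 4, and newData stands for a hypothetical DataFrame of features extracted the same way as for training:

// Hypothetical loading sketch for a scoring job (not part of the original source):
JavaRDD<Object> modelRdd = sc.objectFile(modelPath);
PipelineModel bestModel = (PipelineModel) modelRdd.first();
DataFrame predictions = bestModel.transform(newData);   // newData: freshly extracted features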
 
Example 4
public static void saveStats(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    double[] avgMetrics = model.avgMetrics();
    double bestMetric = 0;
    int bestIndex = 0;

    // Pick the best fold-averaged metric; this assumes a larger-is-better metric
    // such as areaUnderROC.
    for (int i = 0; i < avgMetrics.length; i++) {
        if (avgMetrics[i] > bestMetric) {
            bestMetric = avgMetrics[i];
            bestIndex = i;
        }
    }


    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output+"stats_"+trainingSettings.getClassificationMethod()+".txt");
    fs.delete(statsPath, true);

    FSDataOutputStream fsdos = fs.create(statsPath);

    String avgLine = "Average cross-validation metrics: " + Arrays.toString(avgMetrics);
    String bestMetricLine = "\nBest cross-validation metric [" + trainingSettings.getMetricName() + "]: " + bestMetric;
    String bestSetParamLine = "\nBest set of parameters: " + model.getEstimatorParamMaps()[bestIndex];

    logger.info(avgLine);
    logger.info(bestMetricLine);
    logger.info(bestSetParamLine);


    IOUtils.write(avgLine, fsdos);
    IOUtils.write(bestMetricLine, fsdos);
    IOUtils.write(bestSetParamLine, fsdos);

    PipelineModel pipelineModel = (PipelineModel) model.bestModel();
    for(Transformer t : pipelineModel.stages()) {
        if(t instanceof ClassificationModel) {
            IOUtils.write("\n"+((Model) t).parent().extractParamMap().toString(), fsdos);
            logger.info(((Model) t).parent().extractParamMap().toString());
        }
    }

    fsdos.flush();
    IOUtils.closeQuietly(fsdos);

    debugOutputModel(model, trainingSettings, output);
}
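
The maximum-based scan in saveStats ties it to larger-is-better metrics. A variant that works for either direction could consult the evaluator, as Example 1 already does with getEvaluator().isLargerBetter(); the loop below is a sketch of that variant, assuming avgMetrics is non-empty:

boolean largerIsBetter = model.getEvaluator().isLargerBetter();
double bestMetric = avgMetrics[0];
int bestIndex = 0;
for (int i = 1; i < avgMetrics.length; i++) {
    boolean better = largerIsBetter ? avgMetrics[i] > bestMetric : avgMetrics[i] < bestMetric;
    if (better) {
        bestMetric = avgMetrics[i];
        bestIndex = i;
    }
}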
 
Example 5
/**
 * Trains a classification model for the documents, performing cross-validation and
 * hyperparameter optimization at the same time. The produced model contains the best
 * model and statistics about the runs, which the caller saves afterwards.
 *
 * @param jsc              the JavaSparkContext
 * @param sqlContext       the SQLContext used to build the feature DataFrame
 * @param documents        training documents serialized as SCAS
 * @param trainingSettings configuration for feature extraction and training
 * @return the fitted CrossValidatorModel
 * @throws ResourceInitializationException
 * @throws IOException
 */
public CrossValidatorModel crossValidate(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {

    FeatureExtractorSpark fesr = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);

    //Extract features for each document as LabeledPoints
    DataFrame trainData = fesr.extract(jsc, documents, sqlContext);

    //Persist the data so it is not recomputed on every pass over the parameter grid
    trainData.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    //DataFrame trainData = sqlContext.createDataFrame(labeledPoints, LabeledPoint.class);

    //Wrap the classification model based on the training settings
    SparkClassificationModel model = new SparkClassificationModel(trainingSettings.getClassificationMethod());

    //Train the best model using CrossValidator
    CrossValidatorModel cvModel = model.crossValidate(trainData, trainingSettings);

    return cvModel;
}
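
crossValidate persists trainData for the whole parameter search but never releases it. If nothing downstream reuses the DataFrame, unpersisting it after fitting frees executor memory; a minimal variant of the last lines, assuming no later reuse of trainData:

    CrossValidatorModel cvModel = model.crossValidate(trainData, trainingSettings);

    // Release the cached feature DataFrame once the search is done.
    trainData.unpersist();

    return cvModel;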
 
Example 6
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaModelSelectionViaCrossValidationExample")
    .getOrCreate();

  // Prepare training documents, which are labeled.
  Dataset<Row> training = spark.createDataFrame(Arrays.asList(
    new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
    new JavaLabeledDocument(1L, "b d", 0.0),
    new JavaLabeledDocument(2L, "spark f g h", 1.0),
    new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
    new JavaLabeledDocument(4L, "b spark who", 1.0),
    new JavaLabeledDocument(5L, "g d a y", 0.0),
    new JavaLabeledDocument(6L, "spark fly", 1.0),
    new JavaLabeledDocument(7L, "was mapreduce", 0.0),
    new JavaLabeledDocument(8L, "e spark program", 1.0),
    new JavaLabeledDocument(9L, "a e c l", 0.0),
    new JavaLabeledDocument(10L, "spark compile", 1.0),
    new JavaLabeledDocument(11L, "hadoop software", 0.0)
  ), JavaLabeledDocument.class);

  // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
  Tokenizer tokenizer = new Tokenizer()
    .setInputCol("text")
    .setOutputCol("words");
  HashingTF hashingTF = new HashingTF()
    .setNumFeatures(1000)
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");
  LogisticRegression lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.01);
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

  // We use a ParamGridBuilder to construct a grid of parameters to search over.
  // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
  // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
  ParamMap[] paramGrid = new ParamGridBuilder()
    .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
    .addGrid(lr.regParam(), new double[] {0.1, 0.01})
    .build();

  // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
  // This will allow us to jointly choose parameters for all Pipeline stages.
  // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
  // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
  // is areaUnderROC.
  CrossValidator cv = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(new BinaryClassificationEvaluator())
    .setEstimatorParamMaps(paramGrid).setNumFolds(2);  // Use 3+ in practice

  // Run cross-validation, and choose the best set of parameters.
  CrossValidatorModel cvModel = cv.fit(training);

  // Prepare test documents, which are unlabeled.
  Dataset<Row> test = spark.createDataFrame(Arrays.asList(
    new JavaDocument(4L, "spark i j k"),
    new JavaDocument(5L, "l m n"),
    new JavaDocument(6L, "mapreduce spark"),
    new JavaDocument(7L, "apache hadoop")
  ), JavaDocument.class);

  // Make predictions on test documents. cvModel uses the best model found (lrModel).
  Dataset<Row> predictions = cvModel.transform(test);
  for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
    System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
      + ", prediction=" + r.get(3));
  }

  spark.stop();
}
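
Once fitted, the CrossValidatorModel from this example can also be written to disk and reloaded with the ML persistence API available in Spark 2.0+; the path below is just a placeholder:

  // Optional persistence sketch (path is a placeholder):
  cvModel.write().overwrite().save("/tmp/cv-model");
  CrossValidatorModel sameModel = CrossValidatorModel.load("/tmp/cv-model");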