Java源码示例:org.apache.spark.ml.tuning.CrossValidatorModel
示例1
/**
 * Fits a cross-validated model on the given training data.
 *
 * <p>The pipeline is first set up through {@code createPipeline}. When the configured
 * classification method is {@code LOG_REG}, a {@code classWeightCol} column is appended
 * to the data: rows whose {@code label} equals {@code "1.0"} receive the fraction of
 * all other rows as weight, every other row receives the complementary fraction.
 * The fitted model is stored in the {@code bestModel} field and also returned.
 *
 * @param trainData        the training data; must expose a {@code label} column when
 *                         the classification method is {@code LOG_REG}
 * @param trainingSettings the settings handed to {@code createPipeline}, also
 *                         supplying the number of folds
 * @return the fitted cross-validator model
 */
public CrossValidatorModel crossValidate(DataFrame trainData, TrainingSettings trainingSettings) {
    // Set up the pipeline first; the validator below is wired from the
    // pipeline, evaluator and paramGrid members.
    createPipeline(trainData, trainingSettings);

    DataFrame fittingData = trainData;
    if (classificationMethod.equals(TrainingSettings.ClassificationMethod.LOG_REG)) {
        long positiveCount = fittingData.filter(col("label").equalTo("1.0")).count();
        long totalCount = fittingData.count();
        long otherCount = totalCount - positiveCount;
        // Share of rows that are NOT labelled "1.0"; used as the weight of the "1.0" rows.
        double positiveWeight = (double) otherCount / totalCount;
        double otherWeight = 1.0 - positiveWeight;
        fittingData = fittingData.withColumn(
            "classWeightCol",
            when(col("label").equalTo("1.0"), positiveWeight).otherwise(otherWeight));
    }

    // Treat the whole pipeline as a single estimator so that the parameters of
    // all of its stages are tuned jointly.
    CrossValidator validator = new CrossValidator();
    validator.setEstimator(pipeline);
    validator.setEvaluator(evaluator);
    validator.setEstimatorParamMaps(paramGrid);
    validator.setNumFolds(trainingSettings.getNumFolds());

    // Run the cross-validation and keep the winning model.
    bestModel = validator.fit(fittingData);
    boolean largerIsBetter = bestModel.getEvaluator().isLargerBetter();
    System.out.println("IS LARGER BETTER ?" + largerIsBetter);
    return bestModel;
}
示例2
/**
 * Dumps a textual description of the classifier stage(s) of the best model, both to the
 * logger and to the file {@code output + "debug_" + classificationMethod + ".txt"}.
 *
 * <p>For {@code RANDOM_FOREST} the debug string of every
 * {@code RandomForestClassificationModel} stage is written; for {@code LOG_REG} the
 * {@code toString()} of every {@code LogisticRegressionModel} stage. For any other
 * classification method the file is created but left empty. An existing file at the
 * target path is deleted first.
 *
 * <p>The output stream is closed via try-with-resources, so it is released even when a
 * write fails, and a failure while closing is reported to the caller instead of being
 * silently dropped.
 *
 * @param model            the cross-validation result whose best model is a {@code PipelineModel}
 * @param trainingSettings supplies the classification method
 * @param output           path prefix of the debug file; the file name is appended verbatim,
 *                         so the prefix is expected to end with a separator
 * @throws IOException if the file cannot be deleted, created, written or closed
 */
public static void debugOutputModel(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output + "debug_" + trainingSettings.getClassificationMethod() + ".txt");
    // Resolve the model before touching the file system so that a wrong model type
    // does not leave an empty, half-created file behind.
    PipelineModel pipelineModel = (PipelineModel) model.bestModel();
    int stageCount = pipelineModel.stages().length;

    fs.delete(statsPath, true);
    try (FSDataOutputStream fsdos = fs.create(statsPath)) {
        switch (trainingSettings.getClassificationMethod()) {
            case RANDOM_FOREST:
                for (int i = 0; i < stageCount; i++) {
                    if (pipelineModel.stages()[i] instanceof RandomForestClassificationModel) {
                        RandomForestClassificationModel rfModel =
                            (RandomForestClassificationModel) pipelineModel.stages()[i];
                        // Build the (potentially large) debug string only once.
                        String description = rfModel.toDebugString();
                        IOUtils.write(description, fsdos);
                        logger.info(description);
                    }
                }
                break;
            case LOG_REG:
                for (int i = 0; i < stageCount; i++) {
                    if (pipelineModel.stages()[i] instanceof LogisticRegressionModel) {
                        LogisticRegressionModel lgModel =
                            (LogisticRegressionModel) pipelineModel.stages()[i];
                        String description = lgModel.toString();
                        IOUtils.write(description, fsdos);
                        logger.info(description);
                    }
                }
                break;
            default:
                // No debug representation is available for other methods.
                break;
        }
        fsdos.flush();
    }
}
示例3
/**
 * Trains a model by cross-validation on the SCAS documents read from {@code input} and
 * stores the best model plus its statistics below {@code output/<applicationId>/}.
 *
 * <p>The members {@code folds}, {@code method}, {@code defaultConf}, {@code scalingFactor}
 * and {@code partitions} override the corresponding defaults when they are non-null
 * (presumably they are populated from the command line; confirm against the enclosing class).
 *
 * <p>When {@code partitions} is not given, the number of input partitions is three times
 * the total executor core count, which requires {@code spark.executor.instances} and
 * {@code spark.executor.cores} to be set. Those settings are only read in that case.
 *
 * <p>The Spark context created here is always stopped before the method returns,
 * whether training succeeded or failed.
 *
 * @return {@code 0} on success
 * @throws Exception if reading the input, training or writing the results fails
 */
@Override
protected int run() throws Exception {
    // To run locally, additionally call .setMaster("local[4]") on this configuration.
    SparkConf sparkConf = new SparkConf()
        .setAppName("EntitySalienceTrainingSparkRunner")
        .set("spark.hadoop.validateOutputSpecs", "false")
        .set("spark.yarn.executor.memoryOverhead", "3072")
        .set("spark.rdd.compress", "true")
        .set("spark.core.connection.ack.wait.timeout", "600")
        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .registerKryoClasses(new Class<?>[] {
            SCAS.class, LabeledPoint.class, SparseVector.class, int[].class, double[].class,
            InternalRow[].class, GenericInternalRow.class, Object[].class, GenericArrayData.class,
            VectorIndexer.class});

    TrainingSettings trainingSettings = new TrainingSettings();
    if (folds != null) {
        trainingSettings.setNumFolds(folds);
    }
    if (method != null) {
        trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.valueOf(method));
    }
    if (defaultConf != null) {
        trainingSettings.setAidaDefaultConf(defaultConf);
    }
    if (scalingFactor != null) {
        trainingSettings.setPositiveInstanceScalingFactor(scalingFactor);
    }

    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    try {
        // Add the cache files to each node only if annotation is required.
        // The input documents could already be annotated, and in this case no caches are needed.
        if (trainingSettings.getFeatureExtractor().equals(TrainingSettings.FeatureExtractor.ANNOTATE_AND_ENTITY_SALIENCE)) {
            sc.addFile(trainingSettings.getBigramCountCache());
            sc.addFile(trainingSettings.getKeywordCountCache());
            sc.addFile(trainingSettings.getWordContractionsCache());
            sc.addFile(trainingSettings.getWordExpansionsCache());
            if (trainingSettings.getAidaDefaultConf().equals("db")) {
                sc.addFile(trainingSettings.getDatabaseAida());
            } else {
                sc.addFile(trainingSettings.getCassandraConfig());
            }
        }

        SQLContext sqlContext = new SQLContext(sc);
        FileSystem fs = FileSystem.get(new Configuration());

        int partitionNumber;
        if (partitions != null) {
            partitionNumber = partitions;
        } else {
            // Only consult the executor settings when they are actually needed, so an
            // explicit partition count also works where they are not configured.
            int totalCores = Integer.parseInt(sc.getConf().get("spark.executor.instances"))
                * Integer.parseInt(sc.getConf().get("spark.executor.cores"));
            partitionNumber = 3 * totalCores;
        }

        // Read training documents serialized as SCAS
        JavaRDD<SCAS> documents = sc.sequenceFile(input, Text.class, SCAS.class, partitionNumber).values();

        // Train a model
        TrainingSparkRunner trainingSparkRunner = new TrainingSparkRunner();
        CrossValidatorModel model = trainingSparkRunner.crossValidate(sc, sqlContext, documents, trainingSettings);

        // Everything produced by this run lives below <output>/<applicationId>/
        String runDirectory = output + "/" + sc.getConf().getAppId() + "/";
        String modelPath = runDirectory + "model_" + trainingSettings.getClassificationMethod();

        // Delete the old model if there is one
        fs.delete(new Path(modelPath), true);

        // Save the new model as a single-partition object file
        List<Model<?>> models = new ArrayList<>();
        models.add(model.bestModel());
        sc.parallelize(models, 1).saveAsObjectFile(modelPath);

        // Save the model stats
        SparkClassificationModel.saveStats(model, trainingSettings, runDirectory);
        return 0;
    } finally {
        // Release the cluster resources even when training or saving failed.
        sc.stop();
    }
}
示例4
/**
 * Logs the cross-validation statistics of {@code model} and writes them to the file
 * {@code output + "stats_" + classificationMethod + ".txt"}, then delegates to
 * {@code debugOutputModel} for the model dump.
 *
 * <p>The report contains the average metric of every parameter combination, the best
 * metric, the parameter combination that produced it, and the parameter map of the parent
 * estimator of every {@code ClassificationModel} stage of the best pipeline.
 *
 * <p>"Best" follows the model's evaluator: the largest average metric when
 * {@code isLargerBetter()} is true, the smallest otherwise. The search starts from the
 * first metric rather than from zero, so metrics that are all zero or negative are
 * reported correctly as well.
 *
 * <p>The output stream is closed via try-with-resources; failures while writing or
 * closing are propagated.
 *
 * @param model            the cross-validation result to report on
 * @param trainingSettings supplies the classification method and the metric name
 * @param output           path prefix of the stats file; the file name is appended verbatim
 * @throws IOException              if the file cannot be deleted, created, written or closed
 * @throws IllegalArgumentException if the model carries no cross-validation metrics
 */
public static void saveStats(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    double[] avgMetrics = model.avgMetrics();
    int bestIndex = indexOfBestMetric(avgMetrics, model.getEvaluator().isLargerBetter());
    double bestMetric = avgMetrics[bestIndex];

    String avgLine = "Average cross-validation metrics: " + Arrays.toString(avgMetrics);
    String bestMetricLine = "\nBest cross-validation metric [" + trainingSettings.getMetricName() + "]: " + bestMetric;
    String bestSetParamLine = "\nBest set of parameters: " + model.getEstimatorParamMaps()[bestIndex];
    logger.info(avgLine);
    logger.info(bestMetricLine);
    logger.info(bestSetParamLine);

    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output + "stats_" + trainingSettings.getClassificationMethod() + ".txt");
    fs.delete(statsPath, true);
    try (FSDataOutputStream fsdos = fs.create(statsPath)) {
        IOUtils.write(avgLine, fsdos);
        IOUtils.write(bestMetricLine, fsdos);
        IOUtils.write(bestSetParamLine, fsdos);

        PipelineModel pipelineModel = (PipelineModel) model.bestModel();
        for (Transformer t : pipelineModel.stages()) {
            if (t instanceof ClassificationModel) {
                // Report the parameters the classifier was trained with.
                String params = ((Model) t).parent().extractParamMap().toString();
                IOUtils.write("\n" + params, fsdos);
                logger.info(params);
            }
        }
        fsdos.flush();
    }

    debugOutputModel(model, trainingSettings, output);
}

/** Returns the index of the first largest (or, if requested, first smallest) metric. */
private static int indexOfBestMetric(double[] metrics, boolean largerIsBetter) {
    if (metrics == null || metrics.length == 0) {
        throw new IllegalArgumentException("The model contains no cross-validation metrics");
    }
    int best = 0;
    for (int i = 1; i < metrics.length; i++) {
        boolean improves = largerIsBetter ? metrics[i] > metrics[best] : metrics[i] < metrics[best];
        if (improves) {
            best = i;
        }
    }
    return best;
}
示例5
/**
 * Extracts features from the given documents and fits a cross-validated classification
 * model on them.
 *
 * <p>The feature extractor is chosen by {@code FeatureExtractionFactory} from the training
 * settings. The extracted feature table is persisted (memory and disk, serialized,
 * replicated twice) before training so that it is not recomputed. Training itself is
 * delegated to a {@code SparkClassificationModel} created for the configured
 * classification method.
 *
 * @param jsc              the Spark context passed to the feature extractor
 * @param sqlContext       the SQL context passed to the feature extractor
 * @param documents        the documents to extract features from
 * @param trainingSettings selects the feature extractor and the classification method and
 *                         is forwarded to the training step
 * @return the cross-validator model produced by the training step
 * @throws ResourceInitializationException declared by this method; the visible code does
 *                                         not show which call raises it
 * @throws IOException                     declared by this method; the visible code does
 *                                         not show which call raises it
 */
public CrossValidatorModel crossValidate(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {
    // Turn every document into a row of the feature table.
    FeatureExtractorSpark extractor = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);
    DataFrame features = extractor.extract(jsc, documents, sqlContext);

    // Keep the features around instead of recomputing them on every access.
    features.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    // Pick the classifier wrapper matching the settings and let it run the cross-validation.
    SparkClassificationModel classifier =
        new SparkClassificationModel(trainingSettings.getClassificationMethod());
    return classifier.crossValidate(features, trainingSettings);
}
示例6
/**
 * Demonstrates model selection via cross-validation: a tokenizer / hashing-TF / logistic
 * regression pipeline is tuned over a 3 x 2 parameter grid with two folds on a dozen
 * labeled documents, and the selected model's predictions for four unlabeled documents
 * are printed to standard output.
 *
 * @param args unused
 */
public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("JavaModelSelectionViaCrossValidationExample")
        .getOrCreate();

    // $example on$
    // Labeled documents used for training.
    Dataset<Row> labeledDocs = spark.createDataFrame(
        Arrays.asList(
            new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
            new JavaLabeledDocument(1L, "b d", 0.0),
            new JavaLabeledDocument(2L, "spark f g h", 1.0),
            new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
            new JavaLabeledDocument(4L, "b spark who", 1.0),
            new JavaLabeledDocument(5L, "g d a y", 0.0),
            new JavaLabeledDocument(6L, "spark fly", 1.0),
            new JavaLabeledDocument(7L, "was mapreduce", 0.0),
            new JavaLabeledDocument(8L, "e spark program", 1.0),
            new JavaLabeledDocument(9L, "a e c l", 0.0),
            new JavaLabeledDocument(10L, "spark compile", 1.0),
            new JavaLabeledDocument(11L, "hadoop software", 0.0)),
        JavaLabeledDocument.class);

    // Three pipeline stages: split text into words, hash the words into a feature
    // vector, classify with logistic regression.
    Tokenizer tokenizer = new Tokenizer();
    tokenizer.setInputCol("text");
    tokenizer.setOutputCol("words");

    HashingTF hashingTF = new HashingTF();
    hashingTF.setNumFeatures(1000);
    hashingTF.setInputCol(tokenizer.getOutputCol());
    hashingTF.setOutputCol("features");

    LogisticRegression lr = new LogisticRegression();
    lr.setMaxIter(10);
    lr.setRegParam(0.01);

    PipelineStage[] stages = {tokenizer, hashingTF, lr};
    Pipeline pipeline = new Pipeline().setStages(stages);

    // Search grid: 3 values for hashingTF.numFeatures times 2 values for lr.regParam,
    // i.e. 6 candidate settings.
    ParamGridBuilder gridBuilder = new ParamGridBuilder();
    gridBuilder.addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000});
    gridBuilder.addGrid(lr.regParam(), new double[] {0.1, 0.01});
    ParamMap[] paramGrid = gridBuilder.build();

    // The whole pipeline is the estimator being tuned, so parameters of all stages are
    // chosen jointly. The BinaryClassificationEvaluator's default metric is areaUnderROC.
    CrossValidator crossValidator = new CrossValidator();
    crossValidator.setEstimator(pipeline);
    crossValidator.setEvaluator(new BinaryClassificationEvaluator());
    crossValidator.setEstimatorParamMaps(paramGrid);
    crossValidator.setNumFolds(2); // Use 3+ in practice

    // Run the cross-validation and keep the model fitted with the best parameter setting.
    CrossValidatorModel selectedModel = crossValidator.fit(labeledDocs);

    // Unlabeled documents to predict on.
    Dataset<Row> unlabeledDocs = spark.createDataFrame(
        Arrays.asList(
            new JavaDocument(4L, "spark i j k"),
            new JavaDocument(5L, "l m n"),
            new JavaDocument(6L, "mapreduce spark"),
            new JavaDocument(7L, "apache hadoop")),
        JavaDocument.class);

    // The cross-validator model predicts with the best model it found.
    Dataset<Row> predictions = selectedModel.transform(unlabeledDocs);
    Dataset<Row> report = predictions.select("id", "text", "probability", "prediction");
    for (Row row : report.collectAsList()) {
        String document = "(" + row.get(0) + ", " + row.get(1) + ")";
        String outcome = "prob=" + row.get(2) + ", prediction=" + row.get(3);
        System.out.println(document + " --> " + outcome);
    }
    // $example off$

    spark.stop();
}