Java源码示例:ml.dmlc.xgboost4j.java.DMatrix

示例1
/**
 * Converts a Weka {@link Instances} dataset into a dense XGBoost {@link DMatrix}.
 * All attributes except the class attribute become features; the class value
 * becomes the label. NOTE(review): assumes the class attribute is the last
 * attribute — confirm against how the Instances object is prepared.
 *
 * @param insts Weka dataset with a class attribute set
 * @return dense DMatrix with labels attached
 * @throws XGBoostError if native matrix construction fails
 */
public static DMatrix wekaInstancesToDMatrix(Instances insts) throws XGBoostError {
    final int rowCount = insts.numInstances();
    final int colCount = insts.numAttributes() - 1; // class attribute excluded

    final float[] flat = new float[rowCount * colCount];
    final float[] targets = new float[rowCount];

    for (int r = 0; r < rowCount; r++) {
        final int base = r * colCount;
        for (int c = 0; c < colCount; c++) {
            flat[base + c] = (float) insts.instance(r).value(c);
        }
        targets[r] = (float) insts.instance(r).classValue();
    }

    DMatrix result = new DMatrix(flat, rowCount, colCount);
    result.setLabel(targets);
    return result;
}
 
示例2
/**
 * Builds a dense DMatrix from the accumulated rows. Rows shorter than
 * {@code maxNumColumns} are right-padded with NaN, which the DMatrix
 * constructor below is told to treat as the missing-value marker.
 *
 * @param labels one label per accumulated row
 * @return dense DMatrix with labels attached
 * @throws XGBoostError if the label count does not match the row count, or
 *         native construction fails
 */
@Override
public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
    final int numRows = rows.size();
    if (labels.length != numRows) {
        throw new XGBoostError(
            String.format("labels.length does not match to nrows. labels.length=%d, nrows=%d",
                labels.length, numRows));
    }

    // Pre-fill with NaN so short rows are implicitly padded as "missing".
    final float[] data = new float[numRows * maxNumColumns];
    Arrays.fill(data, Float.NaN);
    for (int i = 0; i < numRows; i++) {
        final float[] row = rows.get(i);
        // Bulk copy instead of an element-by-element loop; trailing cells
        // keep their NaN padding.
        System.arraycopy(row, 0, data, i * maxNumColumns, row.length);
    }

    DMatrix matrix = new DMatrix(data, numRows, maxNumColumns, Float.NaN);
    matrix.setLabel(labels);
    return matrix;
}
 
示例3
/**
 * Runs plain (no early-stopping) XGBoost training for the given number of
 * rounds, reporting progress and an iteration counter each round.
 *
 * @param dtrain   training matrix
 * @param round    number of boosting rounds
 * @param params   booster parameters
 * @param reporter optional Hadoop reporter; may be null
 * @return the trained booster
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
        @Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    Counters.Counter iterCounter = null;
    if (reporter != null) {
        iterCounter = reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");
    }

    final Booster booster = XGBoostUtils.createBooster(dtrain, params);
    for (int i = 0; i < round; i++) {
        reportProgress(reporter);
        setCounterValue(iterCounter, i + 1); // counter is 1-based
        booster.update(dtrain, i);
    }
    return booster;
}
 
示例4
@Test
public void testCreateFromCSREx() throws XGBoostError {
    // NOTE(review): despite the name, this test exercises DenseDMatrixBuilder,
    // not a CSR construction path — consider renaming.
    // sparse matrix
    // 1 0 2 3 0
    // 4 0 2 3 5
    // 3 1 2 5 0
    DenseDMatrixBuilder builder = new DenseDMatrixBuilder(1024);
    builder.nextRow(new float[] {1, 0, 2, 3, 0});
    builder.nextRow(new float[] {4, 0, 2, 3, 5});
    builder.nextRow(new float[] {3, 1, 2, 5, 0});
    float[] label1 = new float[] {1, 0, 1};
    DMatrix dmat1 = builder.buildMatrix(label1);

    // Verify the row count and that labels round-trip through the native matrix.
    Assert.assertEquals(3, dmat1.rowNum());
    float[] label2 = dmat1.getLabel();
    Assert.assertArrayEquals(label1, label2, 0.f);
}
 
示例5
/**
 * Builds the 6x6 test matrix below via the dense builder, feeding each row in
 * "index:value" token form. Labels are all zero (unused by the callers).
 */
private static DMatrix createDenseDMatrix() throws XGBoostError {
    /*
    11  12  13  14  0   0
    0   22  23  0   0   0
    0   0   33  34  35  36
    0   0   0   44  45  0
    0   0   0   0   0   56
    0   0   0   0   0   66
    */
    final String[][] tokenRows = {
        {"0:11", "1:12", "2:13", "3:14"},
        {"1:22", "2:23"},
        {"2:33", "3:34", "4:35", "5:36"},
        {"3:44", "4:45"},
        {"5:56"},
        {"5:66"},
    };
    DenseDMatrixBuilder builder = new DenseDMatrixBuilder(1024);
    for (String[] row : tokenRows) {
        builder.nextRow(row);
    }
    return builder.buildMatrix(new float[tokenRows.length]);
}
 
示例6
/**
 * Builds the 6x6 test matrix below via the sparse (CSR) builder, feeding each
 * row in "index:value" token form. Labels are all zero (unused by the callers).
 */
private static DMatrix createSparseDMatrix() throws XGBoostError {
    /*
    11  12  13  14  0   0
    0   22  23  0   0   0
    0   0   33  34  35  36
    0   0   0   44  45  0
    0   0   0   0   0   56
    0   0   0   0   0   66
    */
    final String[][] tokenRows = {
        {"0:11", "1:12", "2:13", "3:14"},
        {"1:22", "2:23"},
        {"2:33", "3:34", "4:35", "5:36"},
        {"3:44", "4:45"},
        {"5:56"},
        {"5:66"},
    };
    SparseDMatrixBuilder builder = new SparseDMatrixBuilder(1024);
    for (String[] row : tokenRows) {
        builder.nextRow(row);
    }
    return builder.buildMatrix(new float[tokenRows.length]);
}
 
示例7
/**
 * Loads the dataset into a sparse (CSR) DMatrix. Each parsed row is expected
 * as: label in column 0, followed by "index:value" feature tokens.
 */
@Override
public DMatrix loadDatasetAsDMatrix() throws Exception {
    final DMatrixBuilder builder = new SparseDMatrixBuilder(1024, false);
    final FloatArrayList labels = new FloatArrayList(1024);

    // Column 0 carries the label; the remaining columns are feature tokens.
    parse(new RowProcessor() {
        @Override
        public void handleRow(String[] columns) throws Exception {
            labels.add(Float.parseFloat(columns[0]));
            builder.nextRow(columns, 1, columns.length);
        }
    });

    return builder.buildMatrix(labels.toArray());
}
 
示例8
/**
 * Loads the dataset into a dense DMatrix and slices it to {@code sliceIndex}.
 * NOTE(review): assumes 35 input columns — 33 plain numeric features, one
 * possibly-missing feature ("?"), and a 1-based class label — confirm against
 * the dataset format.
 */
@Override
public DMatrix loadDatasetAsDMatrix() throws Exception {
    final DMatrixBuilder builder = new DenseDMatrixBuilder(1024);
    final FloatArrayList labels = new FloatArrayList(1024);

    RowProcessor rowHandler = new RowProcessor() {
        @Override
        public void handleRow(String[] columns) throws Exception {
            final float[] features = new float[34];
            // Columns 0..32 are plain numeric features.
            for (int col = 0; col < 33; col++) {
                features[col] = Float.parseFloat(columns[col]);
            }
            // Column 33 may be missing ("?"); treat missing as 0.
            features[33] = columns[33].equals("?") ? 0.f : Float.parseFloat(columns[33]);
            // Column 34 is a 1-based class label; shift to 0-based.
            labels.add(Integer.parseInt(columns[34]) - 1);
            builder.nextRow(features);
        }
    };
    parse(rowHandler);

    return builder.buildMatrix(labels.toArray()).slice(sliceIndex);
}
 
示例9
/**
 * Parses dependency relationships for the given tokens: builds pairwise
 * features for all token pairs, scores every candidate edge with the edge
 * model, installs the score matrix into the graph, and delegates final
 * relationship extraction (with tags) to the graph.
 *
 * @param stokens sentence tokens to parse
 * @return predicted dependency relationships
 * @throws XGBoostError if native prediction fails
 */
public DependencyRelationship[] parse(List<SToken> stokens) throws XGBoostError{
    CoNLLDependencyGraph cgraph = new CoNLLDependencyGraph(stokens);
    // build ftrs
    Float[][] pairFtrs = cgraph.buildAllFtrs();
    float[] flattenPairFtrs = UtilFns.flatten2dFloatArray(pairFtrs);


    // One record per (head, dependent) candidate pair.
    int numRecords = pairFtrs.length;
    int numFtrs = pairFtrs[0].length;
    DMatrix dmatrix = new DMatrix(flattenPairFtrs,numRecords,numFtrs);

    float[][] predictScores = this.edgeScoreModel.predict(dmatrix,false,SmoothNLP.XGBoost_DP_Edge_Model_Predict_Tree_Limit);  // treeLimit tunes prediction time

    // Reshape flat prediction scores into an (n x n) edge-score matrix.
    float[] predictScoresFlatten = UtilFns.flatten2dFloatArray(predictScores);
    float[][] edgeScores = new float[cgraph.size()][cgraph.size()];
    for (int i =0; i<cgraph.size(); i++){
        for (int j = 0; j<cgraph.size(); j++){
            if (i!=j){  // skip self-dependency (a token depending on itself)
                // todo: to be evaluated
                edgeScores[i][j] = predictScoresFlatten[i*cgraph.size()+j];
            }
        }
    }
    cgraph.setEdgeScores(edgeScores);

    return cgraph.parseDependencyRelationships(this.edgeTagModel);
}
 
示例10
/**
 * Trains a binary-logistic XGBoost model on CoNLL-formatted data and saves it
 * to {@code modelAddr}. Watches both train and dev sets for early stopping.
 *
 * @param trainFile     training data path
 * @param devFile       dev/validation data path
 * @param modelAddr     destination for the serialized model
 * @param nround        maximum boosting rounds
 * @param negSampleRate negative sampling rate; also used as scale_pos_weight
 * @param earlyStop     early-stopping rounds
 * @param nthreads      worker thread count
 * @throws IOException if the model output destination cannot be created
 */
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int negSampleRate, int earlyStop, int nthreads) throws IOException{
    final DMatrix trainMatrix = readCoNLL2DMatrix(trainFile,negSampleRate);
    final DMatrix devMatrix = readCoNLL2DMatrix(devFile,negSampleRate);
    try{
        Map<String, Object> params = new HashMap<String, Object>() {
            {
                put("nthread", nthreads);
                put("max_depth", 16);
                put("silent", 0);
                put("objective", "binary:logistic");
                put("colsample_bytree",0.95);
                put("colsample_bylevel",0.95);
                put("eta",0.2);
                put("subsample",0.95);
                put("lambda",0.2);

                put("min_child_weight",5);
                put("scale_pos_weight",negSampleRate);

                // other parameters
                // "objective" -> "multi:softmax", "num_class" -> "6"

                put("eval_metric", "logloss");
                put("tree_method","approx");
            }
        };
        Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
            {
                put("train", trainMatrix);
                put("dev",devMatrix);
            }
        };
        Booster booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
        // Close the model output stream when done (previously left open — leak).
        try (OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr)) {
            booster.saveModel(outstream);
        }
    }catch(XGBoostError e){
        // Best-effort as before, but keep the full stack trace for debugging.
        e.printStackTrace();
    }
}
 
示例11
/**
 * Converts an MLSparseMatrix into an XGBoost DMatrix in CSR layout: a row
 * pointer array of length nRows+1, plus flat column-index and value arrays.
 * Null or empty rows contribute no entries but still get a row pointer.
 *
 * @param matrix source sparse matrix
 * @return CSR-backed DMatrix with the same dimensions and contents
 * @throws XGBoostError if native construction fails
 */
public static DMatrix toDMatrix(final MLSparseMatrix matrix)
		throws XGBoostError {

	// NOTE(review): getNNZ() returns long; the cast overflows for matrices
	// with more than Integer.MAX_VALUE non-zeros — confirm inputs stay below.
	final int nnz = (int) matrix.getNNZ();
	final int nRows = matrix.getNRows();
	final int nCols = matrix.getNCols();

	long[] rowIndex = new long[nRows + 1];
	int[] indexesFlat = new int[nnz];
	float[] valuesFlat = new float[nnz];

	// 'cur' is the running write position into the flat arrays; rowIndex[i]
	// records where row i begins.
	int cur = 0;
	for (int i = 0; i < nRows; i++) {
		MLSparseVector row = matrix.getRow(i);
		if (row == null) {
			rowIndex[i] = cur;
			continue;
		}
		int[] indexes = row.getIndexes();
		int rowNNZ = indexes.length;
		if (rowNNZ == 0) {
			rowIndex[i] = cur;
			continue;
		}
		float[] values = row.getValues();
		rowIndex[i] = cur;

		for (int j = 0; j < rowNNZ; j++, cur++) {
			indexesFlat[cur] = indexes[j];
			valuesFlat[cur] = values[j];
		}
	}
	// Sentinel entry: total number of stored elements.
	rowIndex[nRows] = cur;
	return new DMatrix(rowIndex, indexesFlat, valuesFlat,
			DMatrix.SparseType.CSR, nCols);
}
 
示例12
/**
 * Builds a CSR-format sparse DMatrix from the accumulated row pointers,
 * column indices, and values, then attaches the labels.
 *
 * @param labels one label per accumulated row
 * @return CSR DMatrix with labels attached
 * @throws XGBoostError if native construction fails
 */
@Nonnull
public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
    // NOTE(review): toArray(true) presumably returns a trimmed copy of the
    // backing buffer — confirm against the array-list implementation.
    DMatrix matrix = new DMatrix(rowPointers.toArray(true), columnIndices.toArray(true),
        values.toArray(true), DMatrix.SparseType.CSR, maxNumColumns);
    matrix.setLabel(labels);
    return matrix;
}
 
示例13
/**
 * Builds a DMatrix from labeled points via the iterator-based constructor.
 * The list is copied first so the native iteration is decoupled from the
 * caller's (possibly mutable) list.
 *
 * @param data labeled points with row ids
 * @return DMatrix backed by the points
 * @throws XGBoostError if native construction fails
 */
@Nonnull
public static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
        throws XGBoostError {
    // Covariant copy replaces the former element-by-element loop:
    // every LabeledPointWithRowId is-a LabeledPoint.
    final List<LabeledPoint> points = new ArrayList<LabeledPoint>(data);
    return new DMatrix(points.iterator(), "");
}
 
示例14
/**
 * Creates a Booster bound to the given training matrix.
 *
 * The {@code Booster(Map, DMatrix[])} constructor is not publicly accessible,
 * so it is invoked reflectively with access checks suppressed.
 *
 * @param matrix training matrix the booster will cache
 * @param params booster parameters
 * @return a new Booster instance
 * @throws XGBoostError if native initialization fails
 */
@Nonnull
public static Booster createBooster(@Nonnull DMatrix matrix,
        @Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError,
        IllegalAccessException, InvocationTargetException, InstantiationException {
    Class<?>[] args = {Map.class, DMatrix[].class};
    Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args);
    ctor.setAccessible(true);
    return ctor.newInstance(new Object[] {params, new DMatrix[] {matrix}});
}
 
示例15
/**
 * Best-effort disposal of a DMatrix's native resources. Null-safe and never
 * throws, so it is safe to call from finally blocks.
 *
 * @param matrix matrix to dispose; may be null
 */
public static void close(@Nullable final DMatrix matrix) {
    if (matrix != null) {
        try {
            matrix.dispose();
        } catch (Throwable ignored) {
            // Intentionally swallowed: disposal failures must not mask the
            // caller's own errors.
        }
    }
}
 
示例16
/**
 * Scores a batch of rows with the given booster and forwards the predictions.
 * The temporary DMatrix is always released, and the batch is cleared after
 * forwarding so the caller can reuse it.
 *
 * @param model    trained booster
 * @param rowBatch rows to score; cleared on success
 * @throws HiveException if prediction fails
 */
private void predictAndFlush(@Nonnull final Booster model,
        @Nonnull final List<LabeledPointWithRowId> rowBatch) throws HiveException {
    DMatrix batchMatrix = null;
    final float[][] scores;
    try {
        batchMatrix = XGBoostUtils.createDMatrix(rowBatch);
        scores = model.predict(batchMatrix);
    } catch (XGBoostError e) {
        throw new HiveException("Exception caused at prediction", e);
    } finally {
        // Release native memory regardless of the prediction outcome.
        XGBoostUtils.close(batchMatrix);
    }
    forwardPredicted(rowBatch, scores);
    rowBatch.clear();
}
 
示例17
/**
 * Trains an XGBoost booster from the learning data (optionally watching a
 * validation set) and installs it on the given model.
 *
 * @param model        target model; must be an XGBoostModel
 * @param learningData training data
 * @param validData    optional validation data; may be null
 */
public void learn(PredictiveModel model, LearningData learningData, LearningData validData) {
    try {
        final DMatrix trainingSet = new DMatrix(new XGBoostIterator(learningData), null);
        final Map<String, DMatrix> watches = new HashMap<>();
        if (validData != null) {
            watches.put("Validation", new DMatrix(new XGBoostIterator(validData), null));
        }
        final Booster trained = XGBoost.train(trainingSet, params, round, watches, null, null);
        ((XGBoostModel) model).setXGBooster(trained);
    } catch (XGBoostError e) {
        throw new BadRequestException(e);
    }
}
 
示例18
/**
 * Trains a multi-class (softprob) XGBoost model on CoNLL-formatted data and
 * saves it to {@code modelAddr}. Watches both train and dev sets for early
 * stopping.
 *
 * @param trainFile training data path
 * @param devFile   dev/validation data path
 * @param modelAddr destination for the serialized model
 * @param nround    maximum boosting rounds
 * @param earlyStop early-stopping rounds
 * @param nthreads  worker thread count
 * @throws IOException if the model output destination cannot be created
 */
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int earlyStop,int nthreads ) throws IOException{
    final DMatrix trainMatrix = readCoNLL2DMatrix(trainFile);
    final DMatrix devMatrix = readCoNLL2DMatrix(devFile);
    try{
        Map<String, Object> params = new HashMap<String, Object>() {
            {
                put("nthread", nthreads);
                put("max_depth", 12);
                put("silent", 0);
                put("objective", "multi:softprob");
                put("colsample_bytree",0.90);
                put("colsample_bylevel",0.90);
                put("eta",0.2);
                put("subsample",0.95);
                put("lambda",1.0);

                // tree methods for regulation
                put("min_child_weight",5); // deduplicated: was previously put twice with the same value
                put("max_leaves",128);

                // other parameters
                // "objective" -> "multi:softmax", "num_class" -> "6"

                put("eval_metric", "merror");
                put("tree_method","approx");
                put("num_class",tag2float.size());
            }
        };
        Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
            {
                put("train", trainMatrix);
                put("dev",devMatrix);
            }
        };
        Booster booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
        // Close the model output stream when done (previously left open — leak).
        try (OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr)) {
            booster.saveModel(outstream);
        }
    }catch(XGBoostError e){
        // Best-effort as before, but keep the full stack trace for debugging.
        e.printStackTrace();
    }
}
 
示例19
/**
     * Does the 'actual' initialising and building of the model, as opposed to experimental code
     * setup etc: tunes hyperparameters if requested, converts the Weka training
     * instances to a DMatrix, assembles the XGBoost parameter map (fixed plus
     * tunable values), and trains the booster.
     * @throws Exception if DMatrix conversion or training fails
     */    
    public void buildActualClassifer() throws Exception {
        if(tuneParameters)
            tuneHyperparameters();

        // Always multi:softprob, even for binary problems (see commented
        // alternative below).
        String objective = "multi:softprob"; 
//        String objective = numClasses == 2 ? "binary:logistic" : "multi:softprob";

        trainDMat = wekaInstancesToDMatrix(trainInsts);
        params = new HashMap<String, Object>();
        //todo: this is a mega hack to enforce 1 thread only on cluster (else bad juju).
        //fix some how at some point. 
        if (runSingleThreaded || System.getProperty("os.name").toLowerCase().contains("linux"))
            params.put("nthread", 1);
        // else == num processors by default

        //fixed params
        params.put("silent", 1);
        params.put("objective", objective);
        if(objective.contains("multi"))
            params.put("num_class", numClasses); //required with multiclass problems
        params.put("seed", seed);
        params.put("subsample", rowSubsampling);
        params.put("colsample_bytree", colSubsampling);

        //tunable params (numiterations passed directly to XGBoost.train(...)
        params.put("learning_rate", learningRate);
        params.put("max_depth", maxTreeDepth);
        params.put("min_child_weight", minChildWeight);

        // Watches left empty: evaluation output disabled (see commented lines).
        watches = new HashMap<String, DMatrix>();
//        if (getDebugPrinting() || getDebug())
//        watches.put("train", trainDMat);

//        int earlyStopping = (int) Math.ceil(numIterations / 10.0); 
        //e.g numIts == 25    =>   stop after 3 increases in err 
        //    numIts == 250   =>   stop after 25 increases in err

//        booster = XGBoost.train(trainDMat, params, numIterations, watches, null, null, null, earlyStopping);
        booster = XGBoost.train(trainDMat, params, numIterations, watches, null, null);

    }
 
示例20
/**
 * Builds a DMatrix from the rows accumulated so far and attaches the given
 * labels.
 *
 * @param labels one label per accumulated row
 * @return the constructed DMatrix
 * @throws XGBoostError if native construction fails
 */
@Nonnull
public abstract DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError;
 
示例21
/**
 * Finalizes training once all input rows have been consumed: builds the
 * DMatrix, optionally holds out a validation split for early stopping, trains
 * the booster, and forwards the serialized model downstream. Native resources
 * are released in all paths.
 */
@Override
public void close() throws HiveException {
    final Reporter reporter = getReporter();

    DMatrix dmatrix = null;
    Booster booster = null;
    try {
        dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
        // Drop builder/label references eagerly; their buffers can be large.
        this.matrixBuilder = null;
        this.labels = null;

        final int round = OptionUtils.getInt(params, "num_round");
        final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
        if (earlyStoppingRounds > 0) {
            double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
            long seed = OptionUtils.getLong(params, "seed");

            // Shuffle row indices with a fixed seed so the train/validation
            // split is reproducible across runs.
            int numRows = (int) dmatrix.rowNum();
            int[] rows = MathUtils.permutation(numRows);
            ArrayUtils.shuffle(rows, new Random(seed));

            int numTest = (int) (numRows * validationRatio);
            DMatrix dtrain = null, dtest = null;
            try {
                // First numTest shuffled rows -> validation; remainder -> training.
                dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                booster = train(dtrain, dtest, round, earlyStoppingRounds, params, reporter);
            } finally {
                XGBoostUtils.close(dtrain);
                XGBoostUtils.close(dtest);
            }
        } else {
            booster = train(dmatrix, round, params, reporter);
        }
        onFinishTraining(booster);

        // Output the built model
        String modelId = generateUniqueModelId();
        Text predModel = XGBoostUtils.serializeBooster(booster);

        logger.info("model_id:" + modelId.toString() + ", size:" + predModel.getLength());
        forward(new Object[] {modelId, predModel});
    } catch (Throwable e) {
        throw new HiveException(e);
    } finally {
        // Free native resources even if training or forwarding failed.
        XGBoostUtils.close(dmatrix);
        XGBoostUtils.close(booster);
    }
}
 
示例22
/**
 * Trains a booster with early stopping: after each boosting round the test
 * set is evaluated, the best score/iteration is tracked, and training stops
 * once {@code earlyStoppingRounds} rounds pass without improvement.
 *
 * @param dtrain              training matrix
 * @param dtest               held-out evaluation matrix
 * @param round               maximum boosting rounds
 * @param earlyStoppingRounds rounds without improvement before stopping
 * @param params              booster parameters; must contain
 *                            "maximize_evaluation_metrics"
 * @param reporter            optional Hadoop reporter; may be null
 * @return the trained booster (note: at the last-run iteration, not rolled
 *         back to the best one)
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
        @Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
        @Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    final Counters.Counter iterCounter = (reporter == null) ? null
            : reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");

    final Booster booster = XGBoostUtils.createBooster(dtrain, params);

    // Whether a larger metric is better (e.g. AUC) or worse (e.g. logloss).
    final boolean maximizeEvaluationMetrics =
            OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
    float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
    int bestIteration = 0;

    // evalSet writes the single metric value into this one-element buffer.
    final float[] metricsOut = new float[1];
    for (int iter = 0; iter < round; iter++) {
        reportProgress(reporter);
        setCounterValue(iterCounter, iter + 1);

        booster.update(dtrain, iter);

        String evalInfo =
                booster.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
        logger.info(evalInfo);

        final float score = metricsOut[0];
        if (maximizeEvaluationMetrics) {
            // Update best score if the current score is better (no update when equal)
            if (score > bestScore) {
                bestScore = score;
                bestIteration = iter;
            }
        } else {
            if (score < bestScore) {
                bestScore = score;
                bestIteration = iter;
            }
        }

        if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
            logger.info(
                String.format("early stopping after %d rounds away from the best iteration",
                    earlyStoppingRounds));
            break;
        }
    }

    return booster;
}
 
示例23
/**
 * Verifies that the dense builder produces a 6-row matrix; disposes the
 * native matrix even when the assertion fails (previously leaked on failure).
 */
@Test
public void testDenseMatrix() throws XGBoostError {
    DMatrix matrix = createDenseDMatrix();
    try {
        Assert.assertEquals(6, matrix.rowNum());
    } finally {
        matrix.dispose();
    }
}
 
示例24
/**
 * Verifies that the sparse builder produces a 6-row matrix; disposes the
 * native matrix even when the assertion fails (previously leaked on failure).
 */
@Test
public void testSparseMatrix() throws XGBoostError {
    DMatrix matrix = createSparseDMatrix();
    try {
        Assert.assertEquals(6, matrix.rowNum());
    } finally {
        matrix.dispose();
    }
}
 
示例25
/**
 * Loads the implementation-specific dataset and returns it as a DMatrix.
 *
 * @return the dataset as a DMatrix
 * @throws Exception if reading the dataset or constructing the matrix fails
 */
public abstract DMatrix loadDatasetAsDMatrix() throws Exception;