Java source code examples: ml.dmlc.xgboost4j.java.DMatrix
Example 1
public static DMatrix wekaInstancesToDMatrix(Instances insts) throws XGBoostError {
int numRows = insts.numInstances();
int numCols = insts.numAttributes()-1;
float[] data = new float[numRows*numCols];
float[] labels = new float[numRows];
int ind = 0;
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numCols; j++)
data[ind++] = (float) insts.instance(i).value(j);
labels[i] = (float) insts.instance(i).classValue();
}
DMatrix dmat = new DMatrix(data, numRows, numCols);
dmat.setLabel(labels);
return dmat;
}
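The helper above is essentially a thin wrapper around the dense DMatrix constructor, which expects a row-major float array plus the row and column counts. Below is a minimal standalone sketch of those two calls; the method name and the numbers are illustrative only.
// Minimal sketch: build a 2x3 dense DMatrix from a row-major array and attach labels.
private static DMatrix denseExample() throws XGBoostError {
    float[] data = new float[] {1f, 2f, 3f,   // row 0
                                4f, 5f, 6f};  // row 1
    float[] labels = new float[] {0f, 1f};
    DMatrix dmat = new DMatrix(data, 2, 3);   // nrow = 2, ncol = 3
    dmat.setLabel(labels);
    return dmat;
}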
Example 2
@Override
public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
final int numRows = rows.size();
if (labels.length != numRows) {
throw new XGBoostError(
String.format("labels.length does not match to nrows. labels.length=%d, nrows=%d",
labels.length, numRows));
}
final float[] data = new float[numRows * maxNumColumns];
Arrays.fill(data, Float.NaN);
for (int i = 0; i < numRows; i++) {
final float[] row = rows.get(i);
final int rowPtr = i * maxNumColumns;
for (int j = 0; j < row.length; j++) {
int ij = rowPtr + j;
data[ij] = row[j];
}
}
DMatrix matrix = new DMatrix(data, numRows, maxNumColumns, Float.NaN);
matrix.setLabel(labels);
return matrix;
}
Example 3
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
@Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
InstantiationException, XGBoostError {
final Counters.Counter iterCounter = (reporter == null) ? null
: reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");
final Booster booster = XGBoostUtils.createBooster(dtrain, params);
for (int iter = 0; iter < round; iter++) {
reportProgress(reporter);
setCounterValue(iterCounter, iter + 1);
booster.update(dtrain, iter);
}
return booster;
}
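The manual update loop above is the low-level way to run a fixed number of boosting rounds. For comparison, the higher-level XGBoost.train entry point used in Examples 10, 17 and 19 below performs the same iterations internally. A minimal sketch, assuming dtrain, params and round are supplied by the caller and the same imports as the surrounding examples; the method name is hypothetical.
// Sketch: run the same number of boosting rounds through the high-level API.
private static Booster trainViaHighLevelApi(DMatrix dtrain, Map<String, Object> params,
        int round) throws XGBoostError {
    Map<String, DMatrix> watches = new HashMap<>();  // empty: no evaluation sets to watch
    return XGBoost.train(dtrain, params, round, watches, null, null);
}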
Example 4
@Test
public void testCreateFromCSREx() throws XGBoostError {
// sparse matrix
// 1 0 2 3 0
// 4 0 2 3 5
// 3 1 2 5 0
DenseDMatrixBuilder builder = new DenseDMatrixBuilder(1024);
builder.nextRow(new float[] {1, 0, 2, 3, 0});
builder.nextRow(new float[] {4, 0, 2, 3, 5});
builder.nextRow(new float[] {3, 1, 2, 5, 0});
float[] label1 = new float[] {1, 0, 1};
DMatrix dmat1 = builder.buildMatrix(label1);
Assert.assertEquals(3, dmat1.rowNum());
float[] label2 = dmat1.getLabel();
Assert.assertArrayEquals(label1, label2, 0.f);
}
Example 5
private static DMatrix createDenseDMatrix() throws XGBoostError {
/*
11 12 13 14 0 0
0 22 23 0 0 0
0 0 33 34 35 36
0 0 0 44 45 0
0 0 0 0 0 56
0 0 0 0 0 66
*/
DenseDMatrixBuilder builder = new DenseDMatrixBuilder(1024);
builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
builder.nextRow(new String[] {"1:22", "2:23"});
builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
builder.nextRow(new String[] {"3:44", "4:45"});
builder.nextRow(new String[] {"5:56"});
builder.nextRow(new String[] {"5:66"});
float[] labels = new float[6];
return builder.buildMatrix(labels);
}
Example 6
private static DMatrix createSparseDMatrix() throws XGBoostError {
/*
11 12 13 14 0 0
0 22 23 0 0 0
0 0 33 34 35 36
0 0 0 44 45 0
0 0 0 0 0 56
0 0 0 0 0 66
*/
SparseDMatrixBuilder builder = new SparseDMatrixBuilder(1024);
builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
builder.nextRow(new String[] {"1:22", "2:23"});
builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
builder.nextRow(new String[] {"3:44", "4:45"});
builder.nextRow(new String[] {"5:56"});
builder.nextRow(new String[] {"5:66"});
float[] labels = new float[6];
return builder.buildMatrix(labels);
}
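For reference, the 6x6 matrix drawn in the comments above can also be handed directly to the five-argument CSR constructor that Examples 11 and 12 below use. The sketch writes out the row pointers, column indices and non-zero values by hand for that exact matrix; the method name is illustrative.
// CSR layout of the 6x6 matrix above: 14 non-zero entries in total.
private static DMatrix csrExample() throws XGBoostError {
    long[] rowPointers = new long[] {0, 4, 6, 10, 12, 13, 14};  // one entry per row, plus the end
    int[] colIndices = new int[] {0, 1, 2, 3, 1, 2, 2, 3, 4, 5, 3, 4, 5, 5};
    float[] values = new float[] {11, 12, 13, 14, 22, 23, 33, 34, 35, 36, 44, 45, 56, 66};
    return new DMatrix(rowPointers, colIndices, values, DMatrix.SparseType.CSR, 6);  // 6 columns
}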
Example 7
@Override
public DMatrix loadDatasetAsDMatrix() throws Exception {
final DMatrixBuilder builder = new SparseDMatrixBuilder(1024, false);
final FloatArrayList labels = new FloatArrayList(1024);
RowProcessor proc = new RowProcessor() {
@Override
public void handleRow(String[] splitted) throws Exception {
float label = Float.parseFloat(splitted[0]);
labels.add(label);
builder.nextRow(splitted, 1, splitted.length);
}
};
parse(proc);
return builder.buildMatrix(labels.toArray());
}
Example 8
@Override
public DMatrix loadDatasetAsDMatrix() throws Exception {
final DMatrixBuilder builder = new DenseDMatrixBuilder(1024);
final FloatArrayList labels = new FloatArrayList(1024);
RowProcessor proc = new RowProcessor() {
@Override
public void handleRow(String[] splitted) throws Exception {
final float[] features = new float[34];
for (int i = 0; i <= 32; i++) {
features[i] = Float.parseFloat(splitted[i]);
}
features[33] = splitted[33].equals("?") ? 0.f : Float.parseFloat(splitted[33]);
int label = Integer.parseInt(splitted[34]) - 1;
labels.add(label);
builder.nextRow(features);
}
};
parse(proc);
return builder.buildMatrix(labels.toArray()).slice(sliceIndex);
}
Example 9
public DependencyRelationship[] parse(List<SToken> stokens) throws XGBoostError{
CoNLLDependencyGraph cgraph = new CoNLLDependencyGraph(stokens);
// build ftrs
Float[][] pairFtrs = cgraph.buildAllFtrs();
float[] flattenPairFtrs = UtilFns.flatten2dFloatArray(pairFtrs);
int numRecords = pairFtrs.length;
int numFtrs = pairFtrs[0].length;
DMatrix dmatrix = new DMatrix(flattenPairFtrs,numRecords,numFtrs);
float[][] predictScores = this.edgeScoreModel.predict(dmatrix,false,SmoothNLP.XGBoost_DP_Edge_Model_Predict_Tree_Limit); // adjust treeLimit to optimize prediction time
float[] predictScoresFlatten = UtilFns.flatten2dFloatArray(predictScores);
float[][] edgeScores = new float[cgraph.size()][cgraph.size()];
for (int i =0; i<cgraph.size(); i++){
for (int j = 0; j<cgraph.size(); j++){
if (i!=j){ // skip the case where a token depends on itself
// todo: to be evaluated
edgeScores[i][j] = predictScoresFlatten[i*cgraph.size()+j];
}
}
}
cgraph.setEdgeScores(edgeScores);
return cgraph.parseDependencyRelationships(this.edgeTagModel);
}
Example 10
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int negSampleRate, int earlyStop, int nthreads) throws IOException{
final DMatrix trainMatrix = readCoNLL2DMatrix(trainFile,negSampleRate);
final DMatrix devMatrix = readCoNLL2DMatrix(devFile,negSampleRate);
try{
Map<String, Object> params = new HashMap<String, Object>() {
{
put("nthread", nthreads);
put("max_depth", 16);
put("silent", 0);
put("objective", "binary:logistic");
put("colsample_bytree",0.95);
put("colsample_bylevel",0.95);
put("eta",0.2);
put("subsample",0.95);
put("lambda",0.2);
put("min_child_weight",5);
put("scale_pos_weight",negSampleRate);
// other parameters
// "objective" -> "multi:softmax", "num_class" -> "6"
put("eval_metric", "logloss");
put("tree_method","approx");
}
};
Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
{
put("train", trainMatrix);
put("dev",devMatrix);
}
};
Booster booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr);
booster.saveModel(outstream);
}catch(XGBoostError e){
System.out.println(e);
}
}
Example 11
public static DMatrix toDMatrix(final MLSparseMatrix matrix)
throws XGBoostError {
final int nnz = (int) matrix.getNNZ();
final int nRows = matrix.getNRows();
final int nCols = matrix.getNCols();
long[] rowIndex = new long[nRows + 1];
int[] indexesFlat = new int[nnz];
float[] valuesFlat = new float[nnz];
int cur = 0;
for (int i = 0; i < nRows; i++) {
MLSparseVector row = matrix.getRow(i);
if (row == null) {
rowIndex[i] = cur;
continue;
}
int[] indexes = row.getIndexes();
int rowNNZ = indexes.length;
if (rowNNZ == 0) {
rowIndex[i] = cur;
continue;
}
float[] values = row.getValues();
rowIndex[i] = cur;
for (int j = 0; j < rowNNZ; j++, cur++) {
indexesFlat[cur] = indexes[j];
valuesFlat[cur] = values[j];
}
}
rowIndex[nRows] = cur;
return new DMatrix(rowIndex, indexesFlat, valuesFlat,
DMatrix.SparseType.CSR, nCols);
}
Example 12
@Nonnull
public DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError {
DMatrix matrix = new DMatrix(rowPointers.toArray(true), columnIndices.toArray(true),
values.toArray(true), DMatrix.SparseType.CSR, maxNumColumns);
matrix.setLabel(labels);
return matrix;
}
Example 13
@Nonnull
public static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
throws XGBoostError {
final List<LabeledPoint> points = new ArrayList<>(data.size());
for (LabeledPointWithRowId d : data) {
points.add(d);
}
return new DMatrix(points.iterator(), "");
}
Example 14
@Nonnull
public static Booster createBooster(@Nonnull DMatrix matrix,
@Nonnull Map<String, Object> params) throws NoSuchMethodException, XGBoostError,
IllegalAccessException, InvocationTargetException, InstantiationException {
Class<?>[] args = {Map.class, DMatrix[].class};
Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args);
ctor.setAccessible(true);
return ctor.newInstance(new Object[] {params, new DMatrix[] {matrix}});
}
Example 15
public static void close(@Nullable final DMatrix matrix) {
if (matrix == null) {
return;
}
try {
matrix.dispose();
} catch (Throwable e) {
// ignore errors thrown while disposing
}
}
Example 16
private void predictAndFlush(@Nonnull final Booster model,
@Nonnull final List<LabeledPointWithRowId> rowBatch) throws HiveException {
DMatrix testData = null;
final float[][] predicted;
try {
testData = XGBoostUtils.createDMatrix(rowBatch);
predicted = model.predict(testData);
} catch (XGBoostError e) {
throw new HiveException("Exception caused at prediction", e);
} finally {
XGBoostUtils.close(testData);
}
forwardPredicted(rowBatch, predicted);
rowBatch.clear();
}
Example 17
public void learn(PredictiveModel model, LearningData learningData, LearningData validData) {
try {
DMatrix dtrain = new DMatrix(new XGBoostIterator(learningData), null);
Map<String, DMatrix> watches = new HashMap<>();
if (validData != null) {
watches.put("Validation", new DMatrix(new XGBoostIterator(validData), null));
}
Booster booster = XGBoost.train(dtrain, params, round, watches, null, null);
XGBoostModel boostModel = (XGBoostModel) model;
boostModel.setXGBooster(booster);
} catch (XGBoostError e) {
throw new BadRequestException(e);
}
}
Example 18
public static void trainXgbModel(String trainFile, String devFile, String modelAddr, int nround, int earlyStop,int nthreads ) throws IOException{
final DMatrix trainMatrix = readCoNLL2DMatrix(trainFile);
final DMatrix devMatrix = readCoNLL2DMatrix(devFile);
try{
Map<String, Object> params = new HashMap<String, Object>() {
{
put("nthread", nthreads);
put("max_depth", 12);
put("silent", 0);
put("objective", "multi:softprob");
put("colsample_bytree",0.90);
put("colsample_bylevel",0.90);
put("eta",0.2);
put("subsample",0.95);
put("lambda",1.0);
// tree parameters for regularization
put("min_child_weight",5);
put("max_leaves",128);
// other parameters
// "objective" -> "multi:softmax", "num_class" -> "6"
put("eval_metric", "merror");
put("tree_method","approx");
put("num_class",tag2float.size());
put("min_child_weight",5);
}
};
Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
{
put("train", trainMatrix);
put("dev",devMatrix);
}
};
Booster booster = XGBoost.train(trainMatrix, params, nround, watches, null, null,null,earlyStop);
OutputStream outstream = SmoothNLP.IOAdaptor.create(modelAddr);
booster.saveModel(outstream);
}catch(XGBoostError e){
System.out.println(e);
}
}
Example 19
/**
* Does the 'actual' initialising and building of the model, as opposed to experimental code
* setup etc.
* @throws Exception
*/
public void buildActualClassifer() throws Exception {
if(tuneParameters)
tuneHyperparameters();
String objective = "multi:softprob";
// String objective = numClasses == 2 ? "binary:logistic" : "multi:softprob";
trainDMat = wekaInstancesToDMatrix(trainInsts);
params = new HashMap<String, Object>();
//todo: this is a mega hack to enforce 1 thread only on cluster (else bad juju).
//fix some how at some point.
if (runSingleThreaded || System.getProperty("os.name").toLowerCase().contains("linux"))
params.put("nthread", 1);
// else == num processors by default
//fixed params
params.put("silent", 1);
params.put("objective", objective);
if(objective.contains("multi"))
params.put("num_class", numClasses); //required with multiclass problems
params.put("seed", seed);
params.put("subsample", rowSubsampling);
params.put("colsample_bytree", colSubsampling);
//tunable params (numIterations passed directly to XGBoost.train(...))
params.put("learning_rate", learningRate);
params.put("max_depth", maxTreeDepth);
params.put("min_child_weight", minChildWeight);
watches = new HashMap<String, DMatrix>();
// if (getDebugPrinting() || getDebug())
// watches.put("train", trainDMat);
// int earlyStopping = (int) Math.ceil(numIterations / 10.0);
//e.g numIts == 25 => stop after 3 increases in err
// numIts == 250 => stop after 25 increases in err
// booster = XGBoost.train(trainDMat, params, numIterations, watches, null, null, null, earlyStopping);
booster = XGBoost.train(trainDMat, params, numIterations, watches, null, null);
}
Example 20
@Nonnull
public abstract DMatrix buildMatrix(@Nonnull float[] labels) throws XGBoostError;
Example 21
@Override
public void close() throws HiveException {
final Reporter reporter = getReporter();
DMatrix dmatrix = null;
Booster booster = null;
try {
dmatrix = matrixBuilder.buildMatrix(labels.toArray(true));
this.matrixBuilder = null;
this.labels = null;
final int round = OptionUtils.getInt(params, "num_round");
final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
if (earlyStoppingRounds > 0) {
double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
long seed = OptionUtils.getLong(params, "seed");
int numRows = (int) dmatrix.rowNum();
int[] rows = MathUtils.permutation(numRows);
ArrayUtils.shuffle(rows, new Random(seed));
int numTest = (int) (numRows * validationRatio);
DMatrix dtrain = null, dtest = null;
try {
dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
booster = train(dtrain, dtest, round, earlyStoppingRounds, params, reporter);
} finally {
XGBoostUtils.close(dtrain);
XGBoostUtils.close(dtest);
}
} else {
booster = train(dmatrix, round, params, reporter);
}
onFinishTraining(booster);
// Output the built model
String modelId = generateUniqueModelId();
Text predModel = XGBoostUtils.serializeBooster(booster);
logger.info("model_id:" + modelId.toString() + ", size:" + predModel.getLength());
forward(new Object[] {modelId, predModel});
} catch (Throwable e) {
throw new HiveException(e);
} finally {
XGBoostUtils.close(dmatrix);
XGBoostUtils.close(booster);
}
}
Example 22
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
@Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
@Nonnull final Map<String, Object> params, @Nullable final Reporter reporter)
throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
InstantiationException, XGBoostError {
final Counters.Counter iterCounter = (reporter == null) ? null
: reporter.getCounter("hivemall.XGBoostTrainUDTF$Counter", "iteration");
final Booster booster = XGBoostUtils.createBooster(dtrain, params);
final boolean maximizeEvaluationMetrics =
OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
int bestIteration = 0;
final float[] metricsOut = new float[1];
for (int iter = 0; iter < round; iter++) {
reportProgress(reporter);
setCounterValue(iterCounter, iter + 1);
booster.update(dtrain, iter);
String evalInfo =
booster.evalSet(new DMatrix[] {dtest}, new String[] {"test"}, iter, metricsOut);
logger.info(evalInfo);
final float score = metricsOut[0];
if (maximizeEvaluationMetrics) {
// Update best score if the current score is better (no update when equal)
if (score > bestScore) {
bestScore = score;
bestIteration = iter;
}
} else {
if (score < bestScore) {
bestScore = score;
bestIteration = iter;
}
}
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
logger.info(
String.format("early stopping after %d rounds away from the best iteration",
earlyStoppingRounds));
break;
}
}
return booster;
}
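The loop above delegates the stop decision to shouldEarlyStop, whose body is not part of this listing. A minimal sketch of what such a predicate would check, namely that the best iteration is at least earlyStoppingRounds rounds behind the current one; this is an assumption for illustration, not the actual implementation.
// Hypothetical sketch of the early-stopping predicate referenced above.
private static boolean shouldEarlyStop(final int earlyStoppingRounds, final int iter,
        final int bestIteration) {
    return iter - bestIteration >= earlyStoppingRounds;
}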
Example 23
@Test
public void testDenseMatrix() throws XGBoostError {
DMatrix matrix = createDenseDMatrix();
Assert.assertEquals(6, matrix.rowNum());
matrix.dispose();
}
Example 24
@Test
public void testSparseMatrix() throws XGBoostError {
DMatrix matrix = createSparseDMatrix();
Assert.assertEquals(6, matrix.rowNum());
matrix.dispose();
}
Example 25
public abstract DMatrix loadDatasetAsDMatrix() throws Exception;