Java源码示例:org.apache.flink.api.common.functions.RichMapPartitionFunction
示例1
/**
 * Verifies that a user-supplied {@link Configuration} can be handed to the local environment.
 *
 * <p>Every parallel subtask emits exactly one record (its own subtask index), and the test
 * asserts that the collected result holds {@code PARALLELISM} records.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
    final Configuration customConfig = new Configuration();
    customConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

    final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(customConfig);
    env.setParallelism(ExecutionConfig.PARALLELISM_AUTO_MAX);
    env.getConfig().disableSysoutLogging();

    // Emits one record per partition: the index of the subtask processing it.
    final RichMapPartitionFunction<Integer, Integer> emitSubtaskIndex =
        new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                out.collect(getRuntimeContext().getIndexOfThisSubtask());
            }
        };

    final List<Integer> subtaskIndices = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(emitSubtaskIndex)
        .collect();

    assertEquals(PARALLELISM, subtaskIndices.size());
}
示例2
/**
 * Checks that a custom {@link Configuration} object is accepted by the local environment.
 *
 * <p>The job emits one record per parallel subtask; the number of collected records must
 * equal {@code PARALLELISM}, the slot count written into the configuration.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
    final Configuration slotConfig = new Configuration();
    slotConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

    final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(slotConfig);
    env.getConfig().disableSysoutLogging();

    final DataSet<Integer> subtaskIndices = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                // One output record per partition: the index of the executing subtask.
                final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                out.collect(subtaskIndex);
            }
        });

    final List<Integer> collected = subtaskIndices.collect();
    assertEquals(PARALLELISM, collected.size());
}
示例3
/**
 * Wraps the Java implementation of a {@link MapPartitionsDescriptor} in a Flink
 * {@link RichMapPartitionFunction}.
 *
 * @param descriptor descriptor whose Java implementation is applied to each partition
 * @param fex        execution context passed to the wrapped function when the Flink function is opened
 * @return a Flink function that forwards every element produced by the wrapped function
 */
public <I, O> RichMapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor, FlinkExecutionContext fex) {
    final FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>> udf =
        (FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>>) descriptor.getJavaImplementation();

    return new RichMapPartitionFunction<I, O>() {
        @Override
        public void open(Configuration parameters) throws Exception {
            // Hand the execution context to the wrapped function before any partition is processed.
            udf.open(fex);
        }

        @Override
        public void mapPartition(Iterable<I> iterable, Collector<O> collector) throws Exception {
            for (O element : udf.apply(iterable)) {
                collector.collect(element);
            }
        }
    };
}
示例4
/**
 * Confirms that the local environment accepts a caller-provided {@link Configuration}.
 *
 * <p>Each parallel subtask contributes a single record, so the collected list is expected
 * to contain {@code PARALLELISM} entries.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
    final Configuration userConfig = new Configuration();
    userConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);
    final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(userConfig);

    // Emits the index of the executing subtask once per partition.
    final RichMapPartitionFunction<Integer, Integer> indexEmitter =
        new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                out.collect(getRuntimeContext().getIndexOfThisSubtask());
            }
        };

    final DataSet<Integer> perSubtask = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(indexEmitter);

    final List<Integer> emitted = perSubtask.collect();
    assertEquals(PARALLELISM, emitted.size());
}
示例5
/**
 * Verifies that a parallelism set on the environment is used even though a custom
 * configuration is passed to the remote environment.
 *
 * <p>Every parallel subtask emits one record, so the result must hold {@code USER_DOP} records.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
    final Configuration clientConfig = new Configuration();
    clientConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

    // Connect to the REST endpoint exposed by the shared mini cluster.
    final URI restEndpoint = MINI_CLUSTER_RESOURCE.getRestAddres();
    final ExecutionEnvironment env = ExecutionEnvironment.createRemoteEnvironment(
        restEndpoint.getHost(),
        restEndpoint.getPort(),
        clientConfig);
    env.setParallelism(USER_DOP);
    env.getConfig().disableSysoutLogging();

    final RichMapPartitionFunction<Integer, Integer> emitSubtaskIndex =
        new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                out.collect(getRuntimeContext().getIndexOfThisSubtask());
            }
        };

    final List<Integer> subtaskIndices = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(emitSubtaskIndex)
        .collect();

    assertEquals(USER_DOP, subtaskIndices.size());
}
示例6
/**
 * Counts the elements of every partition of the given data set by iterating over all of them.
 *
 * @param input the DataSet whose partitions are counted
 * @return a data set of (subtask index, number of elements in that subtask's partition) tuples
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
            long elementCount = 0;
            // The elements themselves are irrelevant; only their number matters.
            for (T ignored : values) {
                elementCount++;
            }

            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
            out.collect(new Tuple2<>(subtaskIndex, elementCount));
        }
    });
}
示例7
/**
 * Pairs every element of the input data set with a unique {@link Long} id.
 *
 * <p>Each parallel task keeps a counter that grows by one per record. The id of a record is
 * that counter shifted left by the bit size of {@code (number of parallel tasks - 1)}, plus
 * the index of the task, which keeps ids from different tasks apart.
 *
 * @param input the input data set
 * @return a data set of (id, element) tuples
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long maxBitSize = getBitSize(Long.MAX_VALUE);
        long shiftBits = 0;
        long counter = 0;
        long subtaskIndex = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            shiftBits = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
            subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                final long label = (counter << shiftBits) + subtaskIndex;
                // Stop once the shifted counter would no longer fit into the bits of a long.
                if (getBitSize(counter) + shiftBits >= maxBitSize) {
                    throw new Exception("Exceeded Long value range while generating labels");
                }
                out.collect(new Tuple2<>(label, value));
                counter++;
            }
        }
    });
}
示例8
/**
 * Ensures that the program parallelism chosen by the user is applied when a configuration
 * is supplied to the remote environment as well.
 *
 * <p>The job produces one record per parallel subtask and the test expects {@code USER_DOP}
 * records in total.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
    final Configuration remoteConfig = new Configuration();
    remoteConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

    final URI clusterAddress = MINI_CLUSTER_RESOURCE.getRestAddres();
    final String host = clusterAddress.getHost();
    final int restPort = clusterAddress.getPort();

    final ExecutionEnvironment env =
        ExecutionEnvironment.createRemoteEnvironment(host, restPort, remoteConfig);
    env.setParallelism(USER_DOP);
    env.getConfig().disableSysoutLogging();

    final DataSet<Integer> perSubtask = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                // One record per partition: the index of the subtask that processed it.
                final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
                out.collect(subtaskIndex);
            }
        });

    final List<Integer> collected = perSubtask.collect();
    assertEquals(USER_DOP, collected.size());
}
示例9
/**
 * Walks over all elements of each partition to determine how many elements it holds.
 *
 * @param input the DataSet received as input
 * @return a data set with one (subtask index, element count) tuple per partition
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
    final RichMapPartitionFunction<T, Tuple2<Integer, Long>> partitionCounter =
        new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
            @Override
            public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
                long seen = 0;
                // Only the number of elements is of interest, not their values.
                for (T unused : values) {
                    seen++;
                }
                out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), seen));
            }
        };

    return input.mapPartition(partitionCounter);
}
示例10
/**
 * Assigns a unique {@link Long} id to every element of the input data set.
 *
 * <p>A per-task counter is increased for each record. The record's id is the counter shifted
 * left by the bit size of {@code (number of parallel tasks - 1)} with the task index added,
 * so that ids generated by different tasks do not collide.
 *
 * @param input the input data set
 * @return a data set of tuples holding the generated id and the original element
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long maxBitSize = getBitSize(Long.MAX_VALUE);
        long shiftBits = 0;
        long counter = 0;
        long subtaskIndex = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            shiftBits = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
            subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                final long label = (counter << shiftBits) + subtaskIndex;
                // Abort when counter bits plus shift would exceed the bits available in a long.
                if (getBitSize(counter) + shiftBits >= maxBitSize) {
                    throw new Exception("Exceeded Long value range while generating labels");
                }
                out.collect(new Tuple2<>(label, value));
                counter++;
            }
        }
    });
}
示例11
/**
 * Checks that an explicitly set program parallelism is applied although a configuration is
 * supplied to the remote environment.
 *
 * <p>Each parallel subtask emits a single record; the collected result must therefore
 * contain {@code USER_DOP} records.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
    final Configuration clientConfig = new Configuration();
    clientConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

    final URI restEndpoint = MINI_CLUSTER_RESOURCE.getRestAddres();
    final ExecutionEnvironment env = ExecutionEnvironment.createRemoteEnvironment(
        restEndpoint.getHost(),
        restEndpoint.getPort(),
        clientConfig);
    env.setParallelism(USER_DOP);

    // Emits the index of the executing subtask once per partition.
    final RichMapPartitionFunction<Integer, Integer> indexEmitter =
        new RichMapPartitionFunction<Integer, Integer>() {
            @Override
            public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
                out.collect(getRuntimeContext().getIndexOfThisSubtask());
            }
        };

    final List<Integer> emitted = env
        .createInput(new ParallelismDependentInputFormat())
        .rebalance()
        .mapPartition(indexEmitter)
        .collect();

    assertEquals(USER_DOP, emitted.size());
}
示例12
/**
 * Determines the number of elements in each partition by iterating over the whole partition.
 *
 * @param input the DataSet received as input
 * @return a data set mapping each subtask index to the number of elements it processed
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
            long total = 0;
            // Element values are not inspected; each one just bumps the tally.
            for (T skipped : values) {
                total++;
            }

            final int ownIndex = getRuntimeContext().getIndexOfThisSubtask();
            out.collect(new Tuple2<>(ownIndex, total));
        }
    });
}
示例13
/**
 * Attaches a unique {@link Long} id to each element of the input data set.
 *
 * <p>Every parallel task counts the records it sees. An id is built by shifting that counter
 * left by the bit size of {@code (number of parallel tasks - 1)} and adding the task index,
 * which makes the ids distinct across tasks.
 *
 * @param input the input data set
 * @return a data set of (unique id, element) tuples
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {
    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long maxBitSize = getBitSize(Long.MAX_VALUE);
        long shiftBits = 0;
        long counter = 0;
        long subtaskIndex = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            shiftBits = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
            subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                final long label = (counter << shiftBits) + subtaskIndex;
                // Refuse to continue once the shifted counter no longer fits into a long.
                if (getBitSize(counter) + shiftBits >= maxBitSize) {
                    throw new Exception("Exceeded Long value range while generating labels");
                }
                out.collect(new Tuple2<>(label, value));
                counter++;
            }
        }
    });
}
示例14
/**
 * Pairs every element of the input data set with a {@link Long} index such that the indices
 * are consecutive across all partitions.
 *
 * <p>The per-partition element counts are computed first and broadcast under the name
 * {@code "counts"}. Each subtask sorts them by subtask index and sums the counts of all
 * lower-indexed subtasks to obtain the index of its own first element.
 *
 * @param input the input data set
 * @return a data set of (consecutive index, element) tuples
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
    final DataSet<Tuple2<Integer, Long>> partitionSizes = countElementsPerPartition(input);

    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long nextIndex = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);

            // Materializes the broadcast counts as a list ordered by subtask index.
            final BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>> sortBySubtask =
                new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                    @Override
                    public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
                        List<Tuple2<Integer, Long>> ordered = new ArrayList<>();
                        for (Tuple2<Integer, Long> entry : data) {
                            ordered.add(entry);
                        }
                        Collections.sort(ordered, new Comparator<Tuple2<Integer, Long>>() {
                            @Override
                            public int compare(Tuple2<Integer, Long> left, Tuple2<Integer, Long> right) {
                                return left.f0.compareTo(right.f0);
                            }
                        });
                        return ordered;
                    }
                };
            final List<Tuple2<Integer, Long>> sizesBySubtask =
                getRuntimeContext().getBroadcastVariableWithInitializer("counts", sortBySubtask);

            // The first index of this partition is the total size of all preceding partitions.
            final int ownIndex = getRuntimeContext().getIndexOfThisSubtask();
            for (int subtask = 0; subtask < ownIndex; subtask++) {
                nextIndex += sizesBySubtask.get(subtask).f1;
            }
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                out.collect(new Tuple2<>(nextIndex, value));
                nextIndex++;
            }
        }
    }).withBroadcastSet(partitionSizes, "counts");
}
示例15
/**
 * Runs a {@link RichMapPartitionFunction} on a collection through
 * {@link MapPartitionOperatorBase}, once with object reuse disabled and once with it enabled.
 *
 * <p>The test checks that the runtime context is available in {@code open()}, that both runs
 * parse the input strings into the expected integers, and that {@code open()} and
 * {@code close()} were invoked.
 */
@Test
public void testMapPartitionWithRuntimeContext() {
    try {
        final String taskName = "Test Task";
        final AtomicBoolean opened = new AtomicBoolean();
        final AtomicBoolean closed = new AtomicBoolean();

        final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {
            @Override
            public void open(Configuration parameters) throws Exception {
                opened.set(true);
                // The context must reflect the TaskInfo handed to the operator below.
                RuntimeContext ctx = getRuntimeContext();
                assertEquals(0, ctx.getIndexOfThisSubtask());
                assertEquals(1, ctx.getNumberOfParallelSubtasks());
                assertEquals(taskName, ctx.getTaskName());
            }

            @Override
            public void mapPartition(Iterable<String> values, Collector<Integer> out) {
                for (String s : values) {
                    out.collect(Integer.parseInt(s));
                }
            }

            @Override
            public void close() throws Exception {
                closed.set(true);
            }
        };

        MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
            new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
                parser,
                new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
                taskName);

        List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
        final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);
        ExecutionConfig executionConfig = new ExecutionConfig();

        // First run: object reuse disabled.
        executionConfig.disableObjectReuse();
        List<Integer> resultMutableSafe = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        // Second run: object reuse enabled.
        executionConfig.enableObjectReuse();
        List<Integer> resultRegular = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
        assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
        assertTrue(opened.get());
        assertTrue(closed.get());
    } catch (Exception e) {
        // Fail with the original exception attached as cause, so the test report shows its
        // type and stack trace instead of only a (possibly null) message.
        throw new AssertionError("Unexpected exception: " + e, e);
    }
}
示例16
/**
 * Assigns consecutive {@link Long} indices to all elements of the input data set.
 *
 * <p>Element counts per partition are gathered and broadcast as {@code "counts"}. In
 * {@code open()} each subtask orders those counts by subtask index and adds up the counts of
 * the subtasks in front of it, which yields the index of its first element.
 *
 * @param input the input data set
 * @return a data set of tuples holding a consecutive index and the original element
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
    final DataSet<Tuple2<Integer, Long>> countsPerPartition = countElementsPerPartition(input);

    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long position = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);

            // Turns the broadcast counts into a list sorted by subtask index.
            final BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>> orderedCounts =
                new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                    @Override
                    public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
                        List<Tuple2<Integer, Long>> sorted = new ArrayList<>();
                        for (Tuple2<Integer, Long> count : data) {
                            sorted.add(count);
                        }
                        Collections.sort(sorted, new Comparator<Tuple2<Integer, Long>>() {
                            @Override
                            public int compare(Tuple2<Integer, Long> a, Tuple2<Integer, Long> b) {
                                return a.f0.compareTo(b.f0);
                            }
                        });
                        return sorted;
                    }
                };
            final List<Tuple2<Integer, Long>> counts =
                getRuntimeContext().getBroadcastVariableWithInitializer("counts", orderedCounts);

            // Skip past the elements held by all lower-indexed subtasks.
            final int thisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            for (int earlier = 0; earlier < thisSubtask; earlier++) {
                position += counts.get(earlier).f1;
            }
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                out.collect(new Tuple2<>(position, value));
                position++;
            }
        }
    }).withBroadcastSet(countsPerPartition, "counts");
}
示例17
/**
 * Executes a {@link RichMapPartitionFunction} on a collection via
 * {@link MapPartitionOperatorBase}, first with object reuse disabled and then enabled.
 *
 * <p>It verifies that {@code open()} sees the expected runtime context, that both executions
 * turn the input strings into the expected integers, and that {@code open()} and
 * {@code close()} have been called.
 */
@Test
public void testMapPartitionWithRuntimeContext() {
    try {
        final String taskName = "Test Task";
        final AtomicBoolean opened = new AtomicBoolean();
        final AtomicBoolean closed = new AtomicBoolean();

        final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {
            @Override
            public void open(Configuration parameters) throws Exception {
                opened.set(true);
                // The context must mirror the TaskInfo passed to the operator below.
                RuntimeContext ctx = getRuntimeContext();
                assertEquals(0, ctx.getIndexOfThisSubtask());
                assertEquals(1, ctx.getNumberOfParallelSubtasks());
                assertEquals(taskName, ctx.getTaskName());
            }

            @Override
            public void mapPartition(Iterable<String> values, Collector<Integer> out) {
                for (String s : values) {
                    out.collect(Integer.parseInt(s));
                }
            }

            @Override
            public void close() throws Exception {
                closed.set(true);
            }
        };

        MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
            new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
                parser,
                new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
                taskName);

        List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
        final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);
        ExecutionConfig executionConfig = new ExecutionConfig();

        // Run once without object reuse ...
        executionConfig.disableObjectReuse();
        List<Integer> resultMutableSafe = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        // ... and once with object reuse.
        executionConfig.enableObjectReuse();
        List<Integer> resultRegular = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
        assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
        assertTrue(opened.get());
        assertTrue(closed.get());
    } catch (Exception e) {
        // Attach the original exception as cause so its type and stack trace appear in the
        // test report; a bare message (which may be null) would hide where it was thrown.
        throw new AssertionError("Unexpected exception: " + e, e);
    }
}
示例18
/**
 * Tags every row of {@code data} with an integer split id.
 *
 * <p>The rows are shuffled, the number of rows per partition is broadcast under the name
 * {@code "counts"}, and each subtask then walks its partition while tracking a running
 * global row position to decide which split the current row belongs to.
 *
 * <p>NOTE(review): the split boundaries come from {@code DistributedInfo#startPos} and
 * {@code DistributedInfo#localRowCnt}, which are not visible here. Presumably they divide
 * {@code totalNumInstance} rows into {@code k} contiguous ranges; confirm against that class.
 *
 * @param data batch operator whose rows are to be split
 * @param k    number of splits, passed through to the {@code DistributedInfo} calls
 * @return a data set of (split id, row) tuples
 */
private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> data, int k) {
DataSet<Row> input = shuffle(data.getDataSet());
// Assumed to hold (subtask index, row count) pairs, matching how it is read in open() below.
// TODO confirm against DataSetUtils.countElementsPerPartition.
DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(input);
return input
.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
// Global position assigned to the first row of this subtask's partition.
long taskStart = 0L;
// Sum of the row counts of all partitions.
long totalNumInstance = 0L;
@Override
public void open(Configuration parameters) throws Exception {
List<Tuple2<Integer, Long>> counts1 = getRuntimeContext().getBroadcastVariable("counts");
int taskId = getRuntimeContext().getIndexOfThisSubtask();
for (Tuple2<Integer, Long> cnt : counts1) {
// Rows of partitions with a HIGHER index than this subtask are counted as preceding
// this subtask's rows.
if (taskId < cnt.f0) {
taskStart += cnt.f1;
}
totalNumInstance += cnt.f1;
}
}
@Override
public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) throws Exception {
DistributedInfo distributedInfo = new DefaultDistributedInfo();
// f0: id of the current split; f1: global position at which that split ends.
// Remains (-1, -1) if the scan below finds no match.
Tuple2<Integer, Long> split1 = new Tuple2<>(-1, -1L);
// Running global position of the row about to be emitted.
long lcnt = taskStart;
// Scan the split start positions (indices 0..k inclusive) to find the split containing taskStart.
for (int i = 0; i <= k; ++i) {
long sp = distributedInfo.startPos(i, k, totalNumInstance);
long lrc = distributedInfo.localRowCnt(i, k, totalNumInstance);
// taskStart lies before the start of split i, so the previous split is selected and its
// end is computed as start + row count of split i - 1.
if (taskStart < sp) {
split1.f0 = i - 1;
split1.f1 = distributedInfo.startPos(i - 1, k, totalNumInstance)
+ distributedInfo.localRowCnt(i - 1, k, totalNumInstance);
break;
}
// taskStart is exactly the first position of split i.
if (taskStart == sp) {
split1.f0 = i;
split1.f1 = sp + lrc;
break;
}
}
for (Row val : values) {
// The current position reached the end of the current split: advance to the next split
// and set its end to the current position plus that split's row count.
if (lcnt >= split1.f1) {
split1.f0 += 1;
split1.f1 = distributedInfo.localRowCnt(split1.f0, k, totalNumInstance) + lcnt;
}
out.collect(Tuple2.of(split1.f0, val));
lcnt++;
}
}
}).withBroadcastSet(counts, "counts");
}
示例19
/**
 * Trains an ALS matrix factorization model on the first input operator.
 *
 * <p>The user, item and rate columns are converted to (userId, itemId, rating) tuples, passed
 * to {@link AlsTrain}, and the resulting factors are written out as model rows through
 * {@link AlsModelDataConverter}.
 *
 * @param inputs operators of which the first one provides the user-item-rating rows
 * @return this operator, with its output set to the serialized model
 */
@Override
public AlsTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);

    // Snapshot all parameters into finals so the anonymous functions below can capture them.
    final String userCol = getUserCol();
    final String itemCol = getItemCol();
    final String rateCol = getRateCol();
    final double regularization = getLambda();
    final int numFactors = getRank();
    final int iterations = getNumIter();
    final boolean nonNegative = getNonnegative();
    final boolean implicitPrefs = getImplicitPrefs();
    final double alpha = getAlpha();
    final int numBlocks = getNumBlocks();

    final int userIdx = TableUtil.findColIndexWithAssertAndHint(source.getColNames(), userCol);
    final int itemIdx = TableUtil.findColIndexWithAssertAndHint(source.getColNames(), itemCol);
    final int rateIdx = TableUtil.findColIndexWithAssertAndHint(source.getColNames(), rateCol);

    // tuple3: userId, itemId, rating
    DataSet<Tuple3<Long, Long, Float>> ratings = source.getDataSet()
        .map(new MapFunction<Row, Tuple3<Long, Long, Float>>() {
            @Override
            public Tuple3<Long, Long, Float> map(Row value) {
                long userId = ((Number) value.getField(userIdx)).longValue();
                long itemId = ((Number) value.getField(itemIdx)).longValue();
                float rating = ((Number) value.getField(rateIdx)).floatValue();
                return new Tuple3<>(userId, itemId, rating);
            }
        });

    AlsTrain trainer = new AlsTrain(
        numFactors, iterations, regularization, implicitPrefs, alpha, numBlocks, nonNegative);
    DataSet<Tuple3<Byte, Long, float[]>> factors = trainer.fit(ratings);

    DataSet<Row> modelRows = factors.mapPartition(
        new RichMapPartitionFunction<Tuple3<Byte, Long, float[]>, Row>() {
            @Override
            public void mapPartition(Iterable<Tuple3<Byte, Long, float[]>> values, Collector<Row> out) {
                new AlsModelDataConverter(userCol, itemCol).save(values, out);
            }
        });

    this.setOutput(modelRows, new AlsModelDataConverter(userCol, itemCol).getModelSchema());
    return this;
}
示例20
/**
 * Generates frequent sequence patterns with the PrefixSpan algorithm.
 *
 * <p>The partitioned sequences are distributed by their integer key modulo the number of
 * partitions. Each subtask then collects the sequences it received and generates patterns
 * for those items whose id modulo the parallelism equals the subtask's index.
 *
 * @return frequent sequence patterns together with their supports
 */
public DataSet<Tuple2<int[], Integer>> run() {
    final int parallelism = BatchOperator.getExecutionEnvironmentFromDataSets(sequences).getParallelism();
    DataSet<Tuple2<Integer, int[]>> keyedSequences = partitionSequence(sequences, itemCounts, parallelism);
    final int maxLength = maxPatternLength;

    return keyedSequences
        .partitionCustom(new Partitioner<Integer>() {
            @Override
            public int partition(Integer key, int numPartitions) {
                return key % numPartitions;
            }
        }, 0)
        .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, int[]>, Tuple2<int[], Integer>>() {
            @Override
            public void mapPartition(Iterable<Tuple2<Integer, int[]>> values,
                                     Collector<Tuple2<int[], Integer>> out) throws Exception {
                List<Long> minSupportBroadcast = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                List<Tuple2<Integer, Integer>> itemCountBroadcast =
                    getRuntimeContext().getBroadcastVariable("itemCounts");
                int taskId = getRuntimeContext().getIndexOfThisSubtask();
                long minSuppCnt = minSupportBroadcast.get(0);

                // Gather all sequences routed to this subtask.
                List<int[]> localSequences = new ArrayList<>();
                for (Tuple2<Integer, int[]> keyed : values) {
                    localSequences.add(keyed.f1);
                }

                // One initial postfix per local sequence, identified by its position in the list.
                List<Postfix> initialPostfixes = new ArrayList<>(localSequences.size());
                for (int seqIdx = 0; seqIdx < localSequences.size(); seqIdx++) {
                    initialPostfixes.add(new Postfix(seqIdx));
                }

                // Only handle the items assigned to this subtask.
                for (Tuple2<Integer, Integer> itemCount : itemCountBroadcast) {
                    int item = itemCount.f0;
                    if (item % parallelism == taskId) {
                        generateFreqPattern(localSequences, initialPostfixes, item, minSuppCnt, maxLength, out);
                    }
                }
            }
        })
        .withBroadcastSet(this.minSupportCnt, "minSupportCnt")
        .withBroadcastSet(this.itemCounts, "itemCounts")
        .name("generate_freq_pattern");
}
示例21
/**
 * Zips every element of the input data set with a {@link Long} index; the indices are
 * consecutive over the whole data set.
 *
 * <p>The number of elements per partition is computed and broadcast under {@code "counts"}.
 * When a subtask opens, it sorts those counts by subtask index and sums the counts of all
 * subtasks with a smaller index to find the index of its own first element.
 *
 * @param input the input data set
 * @return a data set of (index, element) tuples with consecutive indices
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
    final DataSet<Tuple2<Integer, Long>> sizes = countElementsPerPartition(input);

    return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
        long currentIndex = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);

            // Builds a list of the broadcast counts, ordered by subtask index.
            final BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>> bySubtaskIndex =
                new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
                    @Override
                    public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
                        List<Tuple2<Integer, Long>> result = new ArrayList<>();
                        for (Tuple2<Integer, Long> item : data) {
                            result.add(item);
                        }
                        Collections.sort(result, new Comparator<Tuple2<Integer, Long>>() {
                            @Override
                            public int compare(Tuple2<Integer, Long> first, Tuple2<Integer, Long> second) {
                                return first.f0.compareTo(second.f0);
                            }
                        });
                        return result;
                    }
                };
            final List<Tuple2<Integer, Long>> sortedSizes =
                getRuntimeContext().getBroadcastVariableWithInitializer("counts", bySubtaskIndex);

            // Offset of this partition = number of elements in all partitions before it.
            final int self = getRuntimeContext().getIndexOfThisSubtask();
            for (int predecessor = 0; predecessor < self; predecessor++) {
                currentIndex += sortedSizes.get(predecessor).f1;
            }
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
            for (T value : values) {
                out.collect(new Tuple2<>(currentIndex, value));
                currentIndex++;
            }
        }
    }).withBroadcastSet(sizes, "counts");
}
示例22
/**
 * Drives a {@link RichMapPartitionFunction} through {@link MapPartitionOperatorBase} on an
 * in-memory collection, with object reuse first disabled and then enabled.
 *
 * <p>Checked are: the runtime context visible in {@code open()}, the parsed integers
 * returned by both executions, and that {@code open()} and {@code close()} were both called.
 */
@Test
public void testMapPartitionWithRuntimeContext() {
    try {
        final String taskName = "Test Task";
        final AtomicBoolean opened = new AtomicBoolean();
        final AtomicBoolean closed = new AtomicBoolean();

        final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {
            @Override
            public void open(Configuration parameters) throws Exception {
                opened.set(true);
                // The context is expected to match the TaskInfo given to the operator below.
                RuntimeContext ctx = getRuntimeContext();
                assertEquals(0, ctx.getIndexOfThisSubtask());
                assertEquals(1, ctx.getNumberOfParallelSubtasks());
                assertEquals(taskName, ctx.getTaskName());
            }

            @Override
            public void mapPartition(Iterable<String> values, Collector<Integer> out) {
                for (String s : values) {
                    out.collect(Integer.parseInt(s));
                }
            }

            @Override
            public void close() throws Exception {
                closed.set(true);
            }
        };

        MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
            new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
                parser,
                new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
                taskName);

        List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
        final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);
        ExecutionConfig executionConfig = new ExecutionConfig();

        // Execution with object reuse switched off.
        executionConfig.disableObjectReuse();
        List<Integer> resultMutableSafe = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        // Execution with object reuse switched on.
        executionConfig.enableObjectReuse();
        List<Integer> resultRegular = op.executeOnCollections(input,
            new RuntimeUDFContext(taskInfo, null, executionConfig,
                new HashMap<String, Future<Path>>(),
                new HashMap<String, Accumulator<?, ?>>(),
                new UnregisteredMetricsGroup()),
            executionConfig);

        assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
        assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
        assertTrue(opened.get());
        assertTrue(closed.get());
    } catch (Exception e) {
        // Keep the original exception as cause: the test report then carries its type and
        // stack trace, whereas a message-only failure loses both (and the message may be null).
        throw new AssertionError("Unexpected exception: " + e, e);
    }
}