Java Source Code Examples: org.apache.flink.api.common.functions.RichMapPartitionFunction
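
RichMapPartitionFunction is the "rich" variant of MapPartitionFunction: a subclass receives all elements of one partition in a single mapPartition(Iterable, Collector) call and additionally gets the open()/close() lifecycle hooks and getRuntimeContext() access inherited from AbstractRichFunction. The examples below are collected from Apache Flink itself and from projects built on top of it. As a minimal, self-contained sketch of the basic pattern (a local execution environment is assumed; imports are the same as in the examples below):

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Long> partitionSums = env.fromElements(1, 2, 3, 4, 5)
		.mapPartition(new RichMapPartitionFunction<Integer, Long>() {
			@Override
			public void open(Configuration parameters) throws Exception {
				// called once per subtask, before the first mapPartition() call
				super.open(parameters);
			}

			@Override
			public void mapPartition(Iterable<Integer> values, Collector<Long> out) {
				// called once per partition: emit the sum of this partition's elements
				long sum = 0;
				for (Integer value : values) {
					sum += value;
				}
				out.collect(sum);
			}
		});

partitionSums.print();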

Example 1
/**
 * Ensure that the user can pass a custom configuration object to the LocalEnvironment.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
	Configuration conf = new Configuration();
	conf.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

	final ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(conf);
	env.setParallelism(ExecutionConfig.PARALLELISM_AUTO_MAX);
	env.getConfig().disableSysoutLogging();

	DataSet<Integer> result = env.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			});
	List<Integer> resultCollection = result.collect();
	assertEquals(PARALLELISM, resultCollection.size());
}
 
Example 2
public <I, O> RichMapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor, FlinkExecutionContext fex){
    FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>> function =
            (FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>>)
                    descriptor.getJavaImplementation();
    return new RichMapPartitionFunction<I, O>() {
        @Override
        public void mapPartition(Iterable<I> iterable, Collector<O> collector) throws Exception {
            function.apply(iterable).forEach(collector::collect);
        }
        @Override
        public void open(Configuration parameters) throws Exception {
            function.open(fex);
        }
    };
}
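
The returned function drops straight into an ordinary DataSet pipeline. A hypothetical wiring sketch (compiler, descriptor, fex and input are placeholder names for a concrete compiler instance, a MapPartitionsDescriptor, a FlinkExecutionContext and an input DataSet; none of them are defined in the snippet above):

RichMapPartitionFunction<String, String> udf = compiler.compile(descriptor, fex);
DataSet<String> result = input.mapPartition(udf);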
 
Example 3
/**
 * Ensure that the program parallelism can be set even if the configuration is supplied.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
	Configuration config = new Configuration();
	config.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

	final URI restAddress = MINI_CLUSTER_RESOURCE.getRestAddress();
	final String hostname = restAddress.getHost();
	final int port = restAddress.getPort();

	final ExecutionEnvironment env = ExecutionEnvironment.createRemoteEnvironment(
			hostname,
			port,
			config
	);
	env.setParallelism(USER_DOP);
	env.getConfig().disableSysoutLogging();

	DataSet<Integer> result = env.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			});
	List<Integer> resultCollection = result.collect();
	assertEquals(USER_DOP, resultCollection.size());
}
 
Example 4
/**
 * Method that iterates over the elements of each partition to count how many
 * elements each partition holds.
 *
 * @param input the DataSet received as input
 * @return a data set of (subtask index, number of elements) tuples, one per partition.
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
			long counter = 0;
			for (T value : values) {
				counter++;
			}
			out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
		}
	});
}
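
A short usage sketch (a local environment with parallelism 2 is assumed; how Flink splits the input across the two partitions is not guaranteed to be exactly even):

ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(2);

DataSet<Long> numbers = env.generateSequence(1, 100);

// prints one (subtask index, element count) tuple per partition, e.g. (0,50) and (1,50)
DataSetUtils.countElementsPerPartition(numbers).print();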
 
Example 5
/**
 * Method that assigns a unique {@link Long} value to all elements in the input data set as described below.
 * <ul>
 *  <li> a map function is applied to the input data set
 *  <li> each map task holds a counter c which is increased for each record
 *  <li> c is shifted by n bits where n = log2(number of parallel tasks)
 * 	<li> to create a unique ID among all tasks, the task id is added to the counter
 * 	<li> for each record, the resulting counter is collected
 * </ul>
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId(DataSet<T> input) {

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		long maxBitSize = getBitSize(Long.MAX_VALUE);
		long shifter = 0;
		long start = 0;
		long taskId = 0;
		long label = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
			taskId = getRuntimeContext().getIndexOfThisSubtask();
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T value : values) {
				label = (start << shifter) + taskId;

				if (getBitSize(start) + shifter < maxBitSize) {
					out.collect(new Tuple2<>(label, value));
					start++;
				} else {
					throw new Exception("Exceeded Long value range while generating labels");
				}
			}
		}
	});
}
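
A worked illustration of the labeling arithmetic (plain Java, not additional API): with a parallelism of 4, shifter = getBitSize(4 - 1) = 2, so each subtask's counter is shifted past the two low bits that hold the task id, and label streams from different subtasks cannot collide:

int shifter = 2; // getBitSize(3) for a parallelism of 4
for (int taskId = 0; taskId < 4; taskId++) {
	for (long start = 0; start < 3; start++) {
		long label = (start << shifter) + taskId;
		System.out.println("subtask " + taskId + " emits " + label);
	}
}
// subtask 0 emits 0, 4, 8; subtask 1 emits 1, 5, 9; ...; subtask 3 emits 3, 7, 11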
 
Example 6
/**
 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
 * consecutive.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

	DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		long start = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
					"counts",
					new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
						@Override
						public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
							// sort the list by task id to calculate the correct offset
							List<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
							for (Tuple2<Integer, Long> datum : data) {
								sortedData.add(datum);
							}
							Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
								@Override
								public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
									return o1.f0.compareTo(o2.f0);
								}
							});
							return sortedData;
						}
					});

			// compute the offset for each partition
			for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
				start += offsets.get(i).f1;
			}
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T value: values) {
				out.collect(new Tuple2<>(start++, value));
			}
		}
	}).withBroadcastSet(elementCount, "counts");
}
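
Unlike zipWithUniqueId, this variant produces consecutive ids but needs two passes over the data: the (subtask index, count) tuples from countElementsPerPartition are broadcast under the name "counts", and in open() each subtask sums the counts of all lower-indexed subtasks to find its starting offset. A short usage sketch:

ExecutionEnvironment env = ExecutionEnvironment.createLocalEnvironment(2);

DataSet<String> items = env.fromElements("a", "b", "c", "d");

// ids are the consecutive values 0..3; which element receives which id depends on the partitioning
DataSetUtils.zipWithIndex(items).print();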
 
Example 7
@Test
public void testMapPartitionWithRuntimeContext() {
	try {
		final String taskName = "Test Task";
		final AtomicBoolean opened = new AtomicBoolean();
		final AtomicBoolean closed = new AtomicBoolean();
		
		final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {
			
			@Override
			public void open(Configuration parameters) throws Exception {
				opened.set(true);
				RuntimeContext ctx = getRuntimeContext();
				assertEquals(0, ctx.getIndexOfThisSubtask());
				assertEquals(1, ctx.getNumberOfParallelSubtasks());
				assertEquals(taskName, ctx.getTaskName());
			}
			
			@Override
			public void mapPartition(Iterable<String> values, Collector<Integer> out) {
				for (String s : values) {
					out.collect(Integer.parseInt(s));
				}
			}
			
			@Override
			public void close() throws Exception {
				closed.set(true);
			}
		};
		
		MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op = 
				new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String,Integer>>(
				parser, new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), taskName);
		
		List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));

		final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);

		ExecutionConfig executionConfig = new ExecutionConfig();
		executionConfig.disableObjectReuse();
		
		List<Integer> resultMutableSafe = op.executeOnCollections(input,
				new RuntimeUDFContext(taskInfo, null, executionConfig,
						new HashMap<String, Future<Path>>(),
						new HashMap<String, Accumulator<?, ?>>(),
						new UnregisteredMetricsGroup()),
				executionConfig);
		
		executionConfig.enableObjectReuse();
		List<Integer> resultRegular = op.executeOnCollections(input,
				new RuntimeUDFContext(taskInfo, null, executionConfig,
						new HashMap<String, Future<Path>>(),
						new HashMap<String, Accumulator<?, ?>>(),
						new UnregisteredMetricsGroup()),
				executionConfig);
		
		assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
		assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
		
		assertTrue(opened.get());
		assertTrue(closed.get());
	}
	catch (Exception e) {
		e.printStackTrace();
		fail(e.getMessage());
	}
}
 
Example 8
private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> data, int k) {

		DataSet<Row> input = shuffle(data.getDataSet());

		DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(input);

		return input
			.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
				long taskStart = 0L;
				long totalNumInstance = 0L;

				@Override
				public void open(Configuration parameters) throws Exception {
					List<Tuple2<Integer, Long>> counts1 = getRuntimeContext().getBroadcastVariable("counts");

					int taskId = getRuntimeContext().getIndexOfThisSubtask();
					for (Tuple2<Integer, Long> cnt : counts1) {

						if (taskId < cnt.f0) {
							taskStart += cnt.f1;
						}

						totalNumInstance += cnt.f1;
					}
				}

				@Override
				public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) throws Exception {
					DistributedInfo distributedInfo = new DefaultDistributedInfo();
					Tuple2<Integer, Long> split1 = new Tuple2<>(-1, -1L);
					long lcnt = taskStart;

					for (int i = 0; i <= k; ++i) {
						long sp = distributedInfo.startPos(i, k, totalNumInstance);
						long lrc = distributedInfo.localRowCnt(i, k, totalNumInstance);

						if (taskStart < sp) {
							split1.f0 = i - 1;
							split1.f1 = distributedInfo.startPos(i - 1, k, totalNumInstance)
								+ distributedInfo.localRowCnt(i - 1, k, totalNumInstance);

							break;
						}

						if (taskStart == sp) {
							split1.f0 = i;
							split1.f1 = sp + lrc;

							break;
						}
					}

					for (Row val : values) {

						if (lcnt >= split1.f1) {
							split1.f0 += 1;
							split1.f1 = distributedInfo.localRowCnt(split1.f0, k, totalNumInstance) + lcnt;
						}

						out.collect(Tuple2.of(split1.f0, val));
						lcnt++;
					}
				}
			}).withBroadcastSet(counts, "counts");
	}
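
The slice boundaries come from DistributedInfo, whose source is not shown here. Based purely on how it is used above, startPos(i, k, total) is taken to be the global index of the first row of slice i when total rows are split as evenly as possible into k slices, and localRowCnt(i, k, total) to be that slice's size. A sketch of the arithmetic under that assumption:

// Hedged sketch of the even-split contract assumed above, not the actual implementation.
static long startPos(int i, int k, long total) {
	long div = total / k;
	long mod = total % k;
	return i * div + Math.min(i, mod); // the first (total % k) slices get one extra row
}

static long localRowCnt(int i, int k, long total) {
	return startPos(i + 1, k, total) - startPos(i, k, total);
}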
 
Example 9
/**
 * Matrix decomposition using the ALS (alternating least squares) algorithm.
 *
 * @param inputs the input batch operators; the first one holds the user-item-rating tuples
 * @return user factors and item factors.
 */
@Override
public AlsTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    final String userColName = getUserCol();
    final String itemColName = getItemCol();
    final String rateColName = getRateCol();

    final double lambda = getLambda();
    final int rank = getRank();
    final int numIter = getNumIter();
    final boolean nonNegative = getNonnegative();
    final boolean implicitPrefs = getImplicitPrefs();
    final double alpha = getAlpha();
    final int numMiniBatches = getNumBlocks();

    final int userColIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), userColName);
    final int itemColIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), itemColName);
    final int rateColIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), rateColName);

    // tuple3: userId, itemId, rating
    DataSet<Tuple3<Long, Long, Float>> alsInput = in.getDataSet()
        .map(new MapFunction<Row, Tuple3<Long, Long, Float>>() {
            @Override
            public Tuple3<Long, Long, Float> map(Row value) {
                return new Tuple3<>(((Number) value.getField(userColIdx)).longValue(),
                    ((Number) value.getField(itemColIdx)).longValue(),
                    ((Number) value.getField(rateColIdx)).floatValue());
            }
        });

    AlsTrain als = new AlsTrain(rank, numIter, lambda, implicitPrefs, alpha, numMiniBatches, nonNegative);
    DataSet<Tuple3<Byte, Long, float[]>> factors = als.fit(alsInput);

    DataSet<Row> output = factors.mapPartition(new RichMapPartitionFunction<Tuple3<Byte, Long, float[]>, Row>() {
        @Override
        public void mapPartition(Iterable<Tuple3<Byte, Long, float[]>> values, Collector<Row> out) {
            new AlsModelDataConverter(userColName, itemColName).save(values, out);
        }
    });

    this.setOutput(output, new AlsModelDataConverter(userColName, itemColName).getModelSchema());
    return this;
}
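
A hedged reading of the factor tuples, inferred only from the code above: in Tuple3<Byte, Long, float[]> the Byte tags which side the factors belong to, the Long is the user or item id, and the float[] is the rank-dimensional latent vector. Splitting the two sides apart might look like this (the assumption that tag 0 marks user factors is illustrative, not confirmed by the snippet):

// assumption: byte tag 0 = user factors, any other value = item factors
DataSet<Tuple3<Byte, Long, float[]>> userFactors = factors.filter(t -> t.f0 == 0);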
 
Example 10
/**
 * Generate frequent sequence patterns using the PrefixSpan algorithm.
 *
 * @return Frequent sequence patterns and their supports.
 */
public DataSet<Tuple2<int[], Integer>> run() {
    final int parallelism = BatchOperator.getExecutionEnvironmentFromDataSets(sequences).getParallelism();
    DataSet<Tuple2<Integer, int[]>> partitionedSequence = partitionSequence(sequences, itemCounts, parallelism);
    final int maxLength = maxPatternLength;

    return partitionedSequence
        .partitionCustom(new Partitioner<Integer>() {
            @Override
            public int partition(Integer key, int numPartitions) {
                return key % numPartitions;
            }
        }, 0)
        .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, int[]>, Tuple2<int[], Integer>>() {
            @Override
            public void mapPartition(Iterable<Tuple2<Integer, int[]>> values,
                                     Collector<Tuple2<int[], Integer>> out) throws Exception {
                List<Long> bc1 = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                List<Tuple2<Integer, Integer>> bc2 = getRuntimeContext().getBroadcastVariable("itemCounts");
                int taskId = getRuntimeContext().getIndexOfThisSubtask();

                long minSuppCnt = bc1.get(0);
                List<int[]> allSeq = new ArrayList<>();
                values.forEach(t -> allSeq.add(t.f1));

                List<Postfix> initialPostfixes = new ArrayList<>(allSeq.size());
                for (int i = 0; i < allSeq.size(); i++) {
                    initialPostfixes.add(new Postfix(i));
                }

                bc2.forEach(itemCount -> {
                    int item = itemCount.f0;
                    if (item % parallelism == taskId) {
                        generateFreqPattern(allSeq, initialPostfixes, item, minSuppCnt, maxLength, out);
                    }
                });
            }
        })
        .withBroadcastSet(this.minSupportCnt, "minSupportCnt")
        .withBroadcastSet(this.itemCounts, "itemCounts")
        .name("generate_freq_pattern");
}
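
The custom partitioner and the subtask-side check apply the same rule: both reduce the item id modulo the parallelism, so the prefix growth for each frequent item runs on exactly one subtask. A tiny illustration (assuming non-negative item ids):

int parallelism = 4;
for (int item = 0; item < 8; item++) {
	// same routing as partitionCustom(...) and the item % parallelism == taskId check above
	System.out.println("item " + item + " is mined by subtask " + (item % parallelism));
}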
 