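I am loading a dataset with ImageDataGenerator and flow_from_dataframe. Using flow_from_dataframe with shuffle=True shuffles the individual images in the dataset.
I want to shuffle the batches instead. If I have 12 images and a batch size of 3, then I have 4 batches: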
batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]
I want to shuffle the order of the batches without shuffling the images inside each batch, so that I get, for example:
batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]
Is this possible with ImageDataGenerator and flow_from_dataframe? Is there a preprocessing function I can use?
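Consider using the tf.data.Dataset API. You can apply the batch operation before shuffling, so whole batches are shuffled while the order inside each batch is preserved.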
import tensorflow as tf

file_names = [f'image_{i}' for i in range(1, 10)]
ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)
for _ in range(3):
    for batch in ds:
        print(batch.numpy())
    print()
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
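You can then use a map operation to load the images from the file names. Because read_image expects a single file name, apply map before batching:
import os

# Hypothetical class names; replace with your dataset's own labels.
class_categories = ['cat', 'dog']

def read_image(file_name):
    image = tf.io.read_file(file_name)
    # expand_animations=False makes decode_image return a 3-D tensor with a known rank.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
    # Assumes relative paths of the form "<class_name>/<file>"; the label is one-hot encoded.
    label = tf.strings.split(file_name, os.sep)[0]
    label = tf.cast(tf.equal(label, class_categories), tf.int32)
    return image, label

ds = tf.data.Dataset.from_tensor_slices(file_names).map(read_image).batch(3).shuffle(3)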
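If you would rather stay with ImageDataGenerator and flow_from_dataframe, one possible workaround is to reorder the DataFrame rows in whole-batch blocks yourself and pass shuffle=False, so that the generator reads the rows in the order given. Below is a minimal sketch of the reordering step only, using plain Python; `shuffle_batches` and `row_order` are hypothetical names for illustration, not part of Keras.

```python
import random

def shuffle_batches(items, batch_size, seed=None):
    """Return items with whole batches permuted; order inside each batch is kept."""
    # Split into consecutive batches. If len(items) is not a multiple of
    # batch_size, the last batch is shorter and batch boundaries may shift
    # once it is moved, so this works cleanly only for exact multiples.
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    random.Random(seed).shuffle(batches)
    # Flatten the permuted batches back into a single list.
    return [item for batch in batches for item in batch]

# Hypothetical row positions for the 12 images in the question.
row_order = shuffle_batches(list(range(12)), batch_size=3, seed=0)
print(row_order)
```

The resulting positions could then be applied to the DataFrame, for example with `df.iloc[row_order]`, before calling flow_from_dataframe with shuffle=False and the same batch size. Note that this order stays fixed across epochs unless you rebuild the generator with a new permutation.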