提问者:小点点

如何使用ImageDataGenerator洗牌批次?


我正在使用ImageDataGenerator和flow_from_dataframe加载数据集。

使用flow\u from\u dataframeshuffle=True对数据集中的图像进行洗牌。

我想洗牌。如果我有12个图像并且批大小=3,那么我有4个批:

batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]

我希望在不洗牌每个批次中的图像的情况下洗牌批次,以便获得例如:

batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]

这可能与ImageDataGenerator和flow_from_dataframe?有预处理功能我可以使用吗?


共1个答案

匿名用户

考虑使用<代码> TF。数据数据集API。您可以在洗牌之前执行批处理操作。

import tensorflow as tf

file_names = [f'image_{i}' for i in range(1, 10)]

ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)

for _ in range(3):
    for batch in ds:
        print(batch.numpy())
    print()
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

然后,可以使用映射操作从文件名加载图像:

def read_image(file_name):
  image = tf.io.read_file(file_name)
  image = tf.image.decode_image(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
  label = tf.strings.split(file_path, os.sep)[0]
  label = tf.cast(tf.equal(label, class_categories), tf.int32)
  return image, label

ds = ds.map(read_image)