How to extract all the features and labels from a tf.data.Dataset and pass them to ImageDataGenerator's flow() method?

I'm working on a small project based on the CIFAR-10 dataset. I've loaded the data with tfds.load(...) and am practicing image augmentation techniques.

Since my dataset is a tf.data.Dataset object, I can't do real-time data augmentation with it, so I want to pass all the features into tf.keras.preprocessing.image.ImageDataGenerator.flow(...) to get real-time augmentation.

But flow(...) only accepts NumPy arrays, not tf.data.Dataset objects.
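If the dataset is finite and fits in memory (CIFAR-10's 50,000 training images take roughly 150 MB as uint8), one way to get the NumPy arrays that flow(...) expects is to batch the whole dataset into a single element instead of looping sample by sample. A minimal sketch, assuming an unbatched (image, label) dataset with known cardinality; dataset_to_numpy is a hypothetical helper, and a small synthetic dataset stands in for CIFAR-10 so the example runs offline:

```python
import numpy as np
import tensorflow as tf


def dataset_to_numpy(ds):
    """Hypothetical helper: turn a finite, unbatched (image, label) dataset into two NumPy arrays.

    Batching every element into one batch avoids a per-sample Python loop.
    Assumes the dataset fits in memory and its cardinality is known.
    """
    n = int(ds.cardinality().numpy())
    if n <= 0:  # -1 = infinite, -2 = unknown cardinality, 0 = empty
        raise ValueError("dataset must be finite, non-empty and have a known cardinality")
    images, labels = next(iter(ds.batch(n)))
    return images.numpy(), labels.numpy()


# Synthetic stand-in for an unbatched CIFAR-10 split: 100 uint8 images with integer labels
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
fake_labels = rng.integers(0, 10, size=(100,), dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((fake_images, fake_labels))

train_x, train_y = dataset_to_numpy(ds)
print(train_x.shape, train_y.shape)
```

The resulting train_x and train_y can then be passed to image_generator.flow(train_x, train_y, batch_size=BATCH_SIZE). For TFDS specifically, tfds.load(..., batch_size=-1) combined with tfds.as_numpy(...) also returns a whole split as arrays.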

Can anyone guide me on how to proceed with this (or suggest an alternative)?

Does tf.image apply its transformations in real time? If not, what is the best approach other than ImageDataGenerator.flow(...)?
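In the sense that matters here, yes: random tf.image ops placed inside Dataset.map are re-run every time an element is produced, so each epoch sees a freshly augmented copy of every image, which is what ImageDataGenerator offers as well. A minimal sketch, assuming unbatched uint8 32x32x3 (image, label) pairs such as tfds.load(..., as_supervised=True) returns; augment is a hypothetical name, and the transforms (flips, a pad-and-crop shift, brightness jitter) are illustrative rather than a port of every ImageDataGenerator option (core tf.image has no arbitrary-angle rotation, only rot90):

```python
import numpy as np
import tensorflow as tf


def augment(image, label):
    """Hypothetical per-example augmentation; re-evaluated every time the element is produced."""
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 [0, 255] -> float32 [0, 1]
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Random shift: zero-pad to 40x40, then crop a random 32x32 window
    image = tf.image.resize_with_crop_or_pad(image, 40, 40)
    image = tf.image.random_crop(image, size=[32, 32, 3])
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep values in [0, 1] after the brightness shift
    return image, label


# Synthetic stand-in for the CIFAR-10 training split
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(256, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 10, size=(256,), dtype=np.int64)

BATCH_SIZE = 32
train_ds = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(256)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # augmentation happens on the fly
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

batch_x, batch_y = next(iter(train_ds))
print(batch_x.shape, batch_y.shape)
```

Because the map runs before batch, each image gets its own random parameters, and passing train_ds to model.fit re-runs the map on every epoch.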

My code:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

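# Legacy TFDS API: subsplit() is deprecated/removed in newer TFDS versions in favor of slicing strings such as 'train[:70%]'
splitting = tfds.Split.ALL.subsplit(weighted=(70,20,10))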
dataset_cifar10,dataset_info = tfds.load(name='cifar10',split=splitting,as_supervised=True,with_info=True)

train_dataset,valid_dataset,test_dataset = dataset_cifar10

BATCH_SIZE = 32

train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=1)

image_generator = ImageDataGenerator(rotation_range=45,width_shift_range=0.15,height_shift_range=0.15,zoom_range=0.2,horizontal_flip=True,vertical_flip=True,rescale=1./255)

train_dataset_generator = image_generator.flow(...)

...
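Another option that keeps everything inside tf.data is Keras's random preprocessing layers, applied with Dataset.map. The sketch below maps the question's ImageDataGenerator arguments to roughly equivalent layers. It assumes TF 2.6 or newer (where these layers live directly under tf.keras.layers), and the correspondence is approximate: RandomRotation takes a fraction of a full turn (0.125 x 360 degrees = 45 degrees), RandomZoom(0.2) zooms by up to 20% in either direction while keeping the aspect ratio, and fill_mode="nearest" matches ImageDataGenerator's default. A synthetic batch again stands in for CIFAR-10.

```python
import numpy as np
import tensorflow as tf

# Approximate counterparts of the question's ImageDataGenerator arguments;
# rescale=1./255 is done by hand in augment() below.
augmentation_layers = [
    tf.keras.layers.RandomRotation(0.125, fill_mode="nearest"),  # 0.125 of a full turn = +/-45 degrees
    tf.keras.layers.RandomTranslation(0.15, 0.15, fill_mode="nearest"),  # height/width shift fractions
    tf.keras.layers.RandomZoom(0.2, fill_mode="nearest"),  # zoom in/out by up to 20%
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
]


def augment(images):
    """Rescale a uint8 batch to [0, 1], then apply each random layer in turn."""
    images = tf.cast(images, tf.float32) / 255.0
    for layer in augmentation_layers:
        images = layer(images, training=True)  # training=True: always apply the random transform
    return images


# Synthetic stand-in for the CIFAR-10 training split
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(256, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 10, size=(256,), dtype=np.int64)

BATCH_SIZE = 32
train_ds = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(256)
    .batch(BATCH_SIZE)
    .map(lambda x, y: (augment(x), y), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

batch_x, batch_y = next(iter(train_ds))
print(batch_x.shape, batch_y.shape)
```

Applying the layers after batch lets them transform a whole batch per call, while each image still gets its own random transform.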
Answer from liulity5201314:

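Right after splitting the train and test datasets, you can iterate over the dataset, append each sample to Python lists, and convert those lists to NumPy arrays that work with ImageDataGenerator. A complete example: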

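import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

cifar10_data,cifar10_info = tfds.load("cifar10",with_info=True,as_supervised=True)
train_data,test_data = cifar10_data['train'],cifar10_data['test']
NUM_CLASSES = 10

# Collect every (image, label) pair as NumPy arrays; labels are one-hot encoded
train_x = []
train_y = []
for image,label in train_data:
    train_x.append(image.numpy())
    train_y.append(tf.keras.utils.to_categorical(label.numpy(),num_classes=NUM_CLASSES))

# Stack into arrays of shape (N, 32, 32, 3) and (N, NUM_CLASSES)
train_x = np.asarray(train_x)
train_y = np.asarray(train_y)

# DataGenerator
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,featurewise_std_normalization=True,horizontal_flip=True)

# Compute the dataset-wide statistics required by featurewise_center / featurewise_std_normalization
datagen.fit(train_x)

# Testing
EPOCHS = 1
BATCH_SIZE = 16
STEPS_PER_EPOCH = len(train_x) // BATCH_SIZE
for e in range(EPOCHS):
    # flow() loops over the data forever, so stop after one pass
    for step,(batch_x,batch_y) in enumerate(datagen.flow(train_x,train_y,batch_size=BATCH_SIZE)):
        print(batch_x.shape,batch_y.shape)
        if step + 1 >= STEPS_PER_EPOCH:
            break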
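The datagen.fit(train_x) call is needed only because featurewise_center and featurewise_std_normalization depend on dataset-wide statistics. As I read the Keras implementation, fit() computes a per-channel mean and standard deviation over all images, rows and columns, and each generated batch is then standardized as (x - mean) / (std + 1e-6). A NumPy sketch of that computation; featurewise_stats and featurewise_standardize are hypothetical helpers, and random data stands in for train_x:

```python
import numpy as np


def featurewise_stats(x):
    """Hypothetical helper: per-channel mean and std over samples, rows and columns (channels-last)."""
    x = x.astype(np.float32)
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))


def featurewise_standardize(batch, mean, std, eps=1e-6):
    """Hypothetical helper: (x - mean) / (std + eps), broadcasting over the channel axis."""
    return (batch.astype(np.float32) - mean) / (std + eps)


# Synthetic stand-in for train_x: 500 uint8 images of shape 32x32x3
rng = np.random.default_rng(0)
train_x = rng.integers(0, 256, size=(500, 32, 32, 3), dtype=np.uint8)

mean, std = featurewise_stats(train_x)  # computed once, on training data only
normalized = featurewise_standardize(train_x, mean, std)
print(normalized.mean(axis=(0, 1, 2)), normalized.std(axis=(0, 1, 2)))
```

Each channel of the result has mean close to 0 and standard deviation close to 1. The statistics are computed on the training set only and then reused for validation and test batches.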
Another answer:
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 70% / 20% / 10% of the CIFAR-10 training split, using the TFDS slicing syntax
splits = ['train[:70%]','train[70%:90%]','train[90%:]']
BATCH_SIZE = 64
# with_info=True is required to unpack (datasets, info); batch_size makes TFDS return batches of 64
dataset_cifar10,dataset_info = tfds.load(name='cifar10',split=splits,as_supervised=True,with_info=True,batch_size=BATCH_SIZE)

train_dataset,valid_dataset,test_dataset = dataset_cifar10

image_generator = ImageDataGenerator(
    rotation_range=45,width_shift_range=0.15,height_shift_range=0.15,zoom_range=0.2,horizontal_flip=True,vertical_flip=True,rescale=1./255)

# custom generator that wraps the image data generator around an already-batched NumPy iterator:
# each batch of BATCH_SIZE images is re-split into num_batches augmented sub-batches of batch_size
def tfds_imgen(ds,imgen,batch_size,num_batches):
    for images,labels in ds:  # ds is a NumPy iterator, so it has no .batch()/.prefetch()
        flow = imgen.flow(images,labels,batch_size=batch_size)
        for _ in range(num_batches):
            yield next(flow)

# call the custom function to get the augmented data generator
train_dataset_generator = tfds_imgen(
    train_dataset.as_numpy_iterator(),image_generator,batch_size=32,num_batches=BATCH_SIZE // 32
)
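train_dataset_generator is a plain Python generator, so it is exhausted after one pass over train_dataset. To get it back into tf.data (for prefetching, or so that every epoch restarts it), one option is tf.data.Dataset.from_generator with a callable that builds a fresh generator each time the dataset is iterated. A sketch, assuming TF 2.4+ for output_signature; augmented_batches is a hypothetical stand-in for tfds_imgen(...) that yields random float32 batches so the example runs without downloading CIFAR-10:

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 32
NUM_BATCHES = 10


def augmented_batches():
    """Hypothetical stand-in for tfds_imgen(...): yields (images, labels) NumPy batches."""
    rng = np.random.default_rng()
    for _ in range(NUM_BATCHES):
        images = rng.random((BATCH_SIZE, 32, 32, 3), dtype=np.float32)
        labels = rng.integers(0, 10, size=(BATCH_SIZE,), dtype=np.int64)
        yield images, labels


train_ds = tf.data.Dataset.from_generator(
    augmented_batches,  # pass the function itself, so each epoch gets a fresh generator
    output_signature=(
        tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),  # last batch may be smaller
        tf.TensorSpec(shape=(None,), dtype=tf.int64),
    ),
).prefetch(tf.data.AUTOTUNE)

for epoch in range(2):
    num_batches = sum(1 for _ in train_ds)
    print(f"epoch {epoch}: {num_batches} batches")  # the generator restarts every epoch
```

To wrap the real generator, the callable would be something like lambda: tfds_imgen(train_dataset.as_numpy_iterator(), image_generator, batch_size=32, num_batches=BATCH_SIZE // 32). ImageDataGenerator.flow yields float32 images by default and TFDS CIFAR-10 labels are int64, which is what the signature above assumes.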
Permalink: https://www.f2er.com/3015189.html
