我正在基于cifar10
数据集进行小型项目。我已经从tfds.load(...)
加载了数据并正在练习图像增强技术。
由于我正在使用tf.data.Dataset
对象(即我的数据集),因此无法实现实时数据增强,因此我想将所有功能传递到tf.keras.preprocessing.image.ImageDataGenerator.flow(...)
中以获得实时功能。增强。
但是此flow(...)
方法接受与tf.data.Dataset
对象无关的NumPy数组。
有人可以在这方面(或任何其他选择)指导我,我该如何进一步进行?
tf.image
实时转换吗?如果没有,除了ImageDataGenerator.flow(...)
之外,最好的方法是什么?
我的代码:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator
splitting = tfds.Split.ALL.subsplit(weighted=(70,20,10))
dataset_cifar10,dataset_info = tfds.load(name='cifar10',split=splitting,as_supervised=True,with_info=True)
train_dataset,valid_dataset,test_dataset = dataset_cifar10
BATCH_SIZE = 32
train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=1)
image_generator = ImageDataGenerator(rotation_range=45,width_shift_range=0.15,height_shift_range=0.15,zoom_range=0.2,horizontal_flip=True,vertical_flip=True,rescale=1./255)
train_dataset_generator = image_generator.flow(...)
...