对于多输入模型,tf.data.Dataset.from_tensor_slices 的替代方法是什么?

我正在尝试制作一个接受两个输入的多输入 Keras 模型。一个输入是图像,第二个输入是一维文本。我将图像的路径存储在数据框中,然后将图像附加到这样的列表中:

from tqdm import tqdm

train_images = []
for image_path in tqdm(train_df['paths']):
  byte_file = tf.io.read_file(image_path)
  img = tf.image.decode_png(byte_file)
  train_images.append(img) 

一维文本输入存储在列表中。对验证集和测试集重复此过程。然后我制作一个数据集,如下所示:

train_protein = tf.expand_dims(padded_train_protein_encode,axis=2)
training_dataset = tf.data.Dataset.from_tensor_slices(((train_protein,train_images),train_Labels)) 

training_dataset = training_dataset.batch(20)

val_protein = tf.expand_dims(padded_val_protein_encode,axis=2)
validation_dataset = tf.data.Dataset.from_tensor_slices(((val_protein,val_images),validation_Labels))
validation_dataset = validation_dataset.batch(20)

test_protein = tf.expand_dims(padded_test_protein_encode,axis=2)
test_dataset = tf.data.Dataset.from_tensor_slices(((test_protein,test_images),test_Labels)) 
test_dataset = test_dataset.batch(20)

我在 Google colab 中运行此程序,即使使用高内存选项,程序也会因内存不足而崩溃。解决此问题的最佳方法是什么?

我已经研究了 tf.data.Dataset.from_generator 作为一个选项,但是当有两个输入时我无法弄清楚如何使它工作。有人可以帮忙吗?

woshisanbi 回答:对于多输入模型,tf.data.Dataset.from_tensor_slices 的替代方法是什么?

这是一种相当常见的疼痛。如果您的数据集太大而无法加载到内存中,那么没有比数据生成器更好的方法了。来自 PyTorch,有 Pythonic 类可以做到这一点,而不必使用 tf.data.Dataset.from_generator。子类化 tf.keras.utils.Sequence 可能是一个优雅的选择。无法访问您的数据集,我无法验证,但类似这样的操作应该可行。

__getitem__ 每批次都被调用。

class TfDataGenerator(tf.keras.utils.Sequence):
    def __init__(self,filepaths,proteins,labels):
        self.filepaths = np.array(filepaths)
        self.proteins = np.array(proteins)
        self.labels = labels

    def __len__(self):
        return len(self.filenames) // self.batch_size

    def __getitem__(self,index):
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        return __generate_x(indexes),labels[indexes]

    def __generate_x(self,indexes):
        x_1 = np.empty((self.batch_size,*self.dim,self.n_channels))
        x_2 = np.empty((self.batch_size,len(self.meta_features)))

        for index in enumerate(indexes):
            image = cv2.imread(self.filepaths[index])
            image = cv2.cvtColor(image,cv2.COLOR_RGB2BGR)
            x_1[num] = image.astype(np.float32)/255.
            x_2[num] = self.proteins[index]

        return [x_1,x_2]

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.filenames))
        if self.shuffle:
            np.random.shuffle(self.indexes)

再次,一个非常粗略的例子,但希望它展示了可以做什么。 Tensorflow 文档 here

过去这让我很头疼,所以希望这个答案能有所帮助。

本文链接:https://www.f2er.com/28388.html

大家都在问