在Keras中使用Functional API检查模型输入时出现错误

我使用Keras根据以下指南制定了数据生成器方案:

https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

我的数据生成器脚本似乎正在正确生成numpy数组。但是,当我使用Functional API创建模型时,出现以下错误:

ValueError:检查模型输入时出错:传递给模型的Numpy数组列表不是模型预期的大小。预计会看到2个数组,但获得了以下1个数组的列表:

我似乎不明白这个错误。我的模型接受两个不同的输入。它需要一个矩阵(大小为7000 x 208)作为卷积层的输入以及一个神经网络的矢量(7000)。然后将这两个分支合并,并提供给完全连接的层,然后再提供输出层。这是我设置网络的方式:

ksize = 2
l2_lambda =  0.0001

i1 = Input(shape=(7000,208))



c1 = Conv1D(128*2,kernel_size=ksize,activation='relu',kernel_regularizer=keras.regularizers.l2(l2_lambda))(i1)
c1 = Conv1D(128*2,kernel_regularizer=keras.regularizers.l2(l2_lambda))(c1)
c1 = AveragePooling1D(pool_size=ksize)(c1)
c1 = Dropout(0.2)(c1)
c1 = flatten()(c1)

i2 = Input(shape=(7000,))
c2 = Dense(64,kernel_regularizer=keras.regularizers.l2(l2_lambda))(i2)
c2 = Dropout(0.1)(c2)

c = concatenate([c1,c2])

x = Dense(256,kernel_initializer='normal',kernel_regularizer=keras.regularizers.l2(l2_lambda))(c)
x = Dropout(0.25)(x)
output = Dense(5,activation='softmax')(x)

model = Model([i1,i2],[output])

model.summary()

model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adam(),metrics=['accuracy'])

model.fit_generator(generator=training_generator,validation_data=validation_generator)

我的生成器脚本基本上会生成一定大小的一批,这样我就不必一次将所有内容都加载到内存中。数据生成脚本


class DataGenerator(keras.utils.Sequence):
    def __init__(self,list_IDs_snp,list_IDs_pos,labels,batch_size=32,n_channels=1,n_classes=5,shuffle=True):
        self.batch_size = batch_size
        self.list_IDs_snp = list_IDs_snp
        self.list_IDs_pos = list_IDs_pos
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.labels = labels
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs_snp) / self.batch_size))

    def __getitem__(self,index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp_snp = [self.list_IDs_snp[k] for k in indexes]
        list_IDs_temp_pos = [self.list_IDs_pos[k] for k in indexes]

        # Generate data
        snp,pos,y = self.__data_generation(list_IDs_temp_snp,list_IDs_temp_pos)

        return snp,y

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs_snp))
        if self.shuffle==True:
            np.random.shuffle(self.indexes)

    def __data_generation(self,list_IDs_temp_snp,list_IDs_temp_pos):
        snp = np.empty((self.batch_size,7000,208))
        pos = np.empty((self.batch_size,7000))
        y = np.empty((self.batch_size),dtype=int)

        for ID in range(len(list_IDs_temp_snp)):
            snp[ID] = np.load(list_IDs_temp_snp[ID])
            pos[ID] = np.load(list_IDs_temp_pos[ID])
            y[ID] = self.labels[list_IDs_temp_snp[ID]]
        return snp,y

此数据生成方案与我一开始共享的链接相同。

为了生成数据,我按如下方式调用脚本:

params = {'batch_size': 3,'n_classes': 5,'n_channels': 1,'shuffle': True}

training_generator = DataGenerator(partition_snp['train'],partition_pos['train'],**params)
validation_generator = DataGenerator(partition_snp['valid'],partition_pos['valid'],**params)

您认为问题可能出在我分别发送partition_snp和partition_pos吗? Partition_snp和partition_pos只是字典,其中包含每个示例的路径。每个字典都有两个键:“火车”和“有效”。

如果有人能解释为什么我遇到上面提到的错误,我将不胜感激。在执行代码时,我打印了矩阵和向量的类型,并显示了numpy array。因此,我不知道为什么会收到此错误。见识将不胜感激。

smazhe 回答:在Keras中使用Functional API检查模型输入时出现错误

问题出在您的__getitem__方法中,您将返回三个元素的元组,而它应该是一个输入列表和一个输出列表,例如一个元组,

def __getitem__(self,index):
    'Generate one batch of data'
    # Generate indexes of the batch
    indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

    # Find list of IDs
    list_IDs_temp_snp = [self.list_IDs_snp[k] for k in indexes]
    list_IDs_temp_pos = [self.list_IDs_pos[k] for k in indexes]

    # Generate data
    snp,pos,y = self.__data_generation(list_IDs_temp_snp,list_IDs_temp_pos)

    return [snp,pos],y

由于只有一个输出,因此不需要列表。

本文链接:https://www.f2er.com/3167228.html

大家都在问