我已经构建了自己的DataGenerator,以便在keras.fit_generator中使用它
在我的训练脚本中,我从2个路径列表中实例化了2个生成器。一种用作训练基因,另一种用作有效基因。 On_epoch_end是针对训练生成器的调用,而不是有效生成器的调用。我需要on_epoch_end回调才能重置我的音量索引,否则在第二个时代,我会收到错误:IndexError:列表索引超出范围(加载卷时)
training_generator = DataGenerator.DataGenerator('TrainingLoader',list_id,mask_id,n_cube=n_cube_train,batch_size=2,dim=(64,64,64),n_channels=1,n_classes=3,shuffle=True,augmentation=True,overlap=4,rotation=0,translation=0,scaling=1,channel_first=False,depth_first=False)
validation_generator = DataGenerator.DataGenerator('ValidLoader',valid_list_id,valid_mask_list_id,n_cube=n_cube_valid,shuffle=False,augmentation=False,depth_first=False)
model3.fit_generator(generator=training_generator,epochs=1000,validation_data=validation_generator,validation_freq=1,verbose=1,workers=0,callbacks=callback)
================================================ ============================= DataGenerator类(keras.utils.Sequence): “为Keras生成数据”
def __init__(self,name,n_cube,batch_size=5,n_classes=10,rotation=10,translation=10,scaling=0.9,depth_first=False):
"""
Initialization of the class
---
:param list_id:
:param labels:
:param batch_size: Number of data to load per batch
:param dim: Dimension of the data
:param n_channels: Number of information per pixel. 1-Grayscale 3-RGB
:param n_classes: Number of mask
:param shuffle: Boolean for shuffling the order of the loading data
"""
self.name = name
self.list_id = list_id
self.mask_id = mask_id
self.batch_size = batch_size
self.dim = dim
self.overlap = overlap
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.augmentation = augmentation
self.rotation = rotation
self.translation = translation
self.scaling = scaling
self.on_epoch_end()
self.offset = 0
self.volume_index = 0
self.cube_index = 0
self.volume_cube_index = []
self.n_cube = n_cube
self.channel_first = channel_first
self.depth_first = depth_first
def __len__(self):
"""
Function that calculate the number of batch needed per epoch
---
:return: The number of batch per epoch
"""
#print(self.name,int(np.floor(self.n_cube / self.batch_size)))
#return int(np.floor(self.n_cube / self.batch_size))
return (self.n_cube + self.batch_size - 1) // self.batch_size # round up
def __getitem__(self,index):
# Generate data
return self.__data_generation()
def on_epoch_end(self):
"""
activate at the beginning and at the end of every epoch.
Shuffle the ids id shuffle = True
---
:return: None
"""
self.offset = 0
self.volume_index = 0
self.cube_index = 0
def __data_generation(self):
"""
:return: a training batch cubes=(n
"""
for i in range(0,self.batch_size):
# Verify if load volume is already done
if self.cube_index == len(self.volume_cube_index):
self.load_volume()
if i == 0:
cubes = dt.get_cube(self.volume,self.volume_cube_index[self.cube_index],self.dim[0],self.dim[2])
masks = dt.get_cube_mask(self.mask,self.dim[2],self.n_classes)
elif i > 0:
temp_cube = dt.get_cube(self.volume,self.dim[2])
cubes = np.concatenate((temp_cube,cubes),axis=0)
temp_mask = dt.get_cube_mask(self.mask,self.n_classes)
masks = np.concatenate((temp_mask,masks),axis=0)
self.cube_index += 1
return cubes,masks
def load_volume(self):
self.volume,self.mask = dt.get_process_volume(data_dir=self.list_id[self.volume_index],mask_dir=self.mask_id[self.volume_index],kernel_widht=self.dim[0],kernel_depth=self.dim[2],overlap=self.overlap,rotation=self.rotation,translation=self.translation,scaling=self.scaling,augmentation=self.augmentation)
self.volume_cube_index = dt.get_cube_index(image=self.volume,resolution=self.dim[0],depth=self.dim[2],shuffle=self.shuffle)
# Reset the cube index,update volume index
self.cube_index = 0
self.volume_index += 1