How to catch any exception during model training in TensorFlow 2

I am training a Unet model with TensorFlow. If there is a problem with any of the images I pass to the model for training, an exception is raised, sometimes an hour or two into a training run. Is there a way to catch any such exception so that my model can move on to the next image and resume training? I tried adding a try/except block to the process_path function shown below, but it has no effect...

def process_path(filePath):
    # catching exceptions here has no effect
    parts = tf.strings.split(filePath,'/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName,'.')
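    # derive the matching mask path: <maskDir>/<image name>-mask.png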
    prefix = tf.convert_to_tensor(maskDir,dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.png",dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2],suffix))
    maskPath = tf.strings.join((prefix,maskFileName),separator='/')

    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.io.read_file(maskPath)
    oneHot = decodeMask(mask)
    img.set_shape([256,256,3])
    oneHot.set_shape([256,10])
    return img,oneHot

trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
batchSize = 32

allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))

trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path,num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.shuffle(1000).repeat()
validDataSet = validDataSet.map(process_path)
validDataSet = validDataSet.batch(batchSize)

imageHeight = 256
imageWidth = 256
channels = 3

inputImage = Input((imageHeight,imageWidth,channels),name='img') 
model = baseUnet.get_unet(inputImage,n_filters=16,dropout=0.05,batchnorm=True)
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001),
    ModelCheckpoint(outputModel, verbose=1, save_best_only=True, save_weights_only=False)
]

BATCH_SIZE = 32
BUFFER_SIZE = 1000
EPOCHS = 20

stepsPerEpoch = int(trainSize / BATCH_SIZE)
validationSteps = int(validSize / BATCH_SIZE)

model_history = model.fit(trainDataSet,
                          epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)
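
As far as I can tell, the reason the try/except inside process_path never fires is that tf.data traces the Python function into a graph only once; the error happens later, when the graph ops actually run on a bad file. A small standalone reproduction (illustrative only, not my real pipeline):

import tensorflow as tf

def mapped(x):
    try:
        # this Python code runs only once, while tf.data traces mapped()
        # into a graph; the resulting ops execute later, outside this try
        print("tracing mapped()")
        return tf.strings.to_number(x)   # fails at run time for "oops"
    except Exception:
        return tf.constant(0.0)          # never reached for run-time errors

ds = tf.data.Dataset.from_tensor_slices(["1", "2", "oops"]).map(mapped)
for value in ds:                         # raises InvalidArgumentError on "oops"
    print(value.numpy())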

The following link shows a similar situation and explains that "the Python function is only executed once to build the function graph, so try/except statements there have no effect." The link does show how to iterate over the dataset and catch errors:

dataset = ...
iterator = iter(dataset)

while True:
  try:
    elem = next(iterator)
    ...
  except InvalidArgumentError:
    ...
  except StopIteration:
    break
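
I suppose I could build on that pattern and drive training myself, roughly like this untested sketch (reusing model, trainDataSet, stepsPerEpoch and EPOCHS from above, and giving up model.fit and the Keras callbacks):

import tensorflow as tf

# untested sketch: skip any batch whose decoding raises an error
iterator = iter(trainDataSet)
for step in range(EPOCHS * stepsPerEpoch):
    try:
        images, masks = next(iterator)
    except tf.errors.InvalidArgumentError:
        continue                          # skip the failing batch and move on
    loss, acc = model.train_on_batch(images, masks)   # [loss, accuracy] per compile()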

However, I am looking for a way to catch the error during the training process itself. Is this possible?
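
For reference, one related mechanism is tf.data.experimental.ignore_errors(), which drops elements that raise instead of propagating the error. A rough sketch of how it might slot into the pipeline above (untested against my data):

trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.shuffle(1000).repeat()
trainDataSet = trainDataSet.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
trainDataSet = trainDataSet.apply(tf.data.experimental.ignore_errors())  # drop failing elements
trainDataSet = trainDataSet.batch(batchSize)
trainDataSet = trainDataSet.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

That would avoid the crash, but it silently discards data rather than letting me see or handle the error while the model is training.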
