无法使用tf.io.decode_image解码一批图像

我正在尝试进行面部关键点回归。我成功创建了带有图像和标签编码的TFRecord文件(标签是面部kypoint)。

然后,我开始将数据(图像和关键点)加载到内存中(遵循dependency scope处的指南)。我想先批处理所有图像,然后按照该指南中的说明对图像进行解码。但是,这不起作用。如果我的理解是正确的,我只能在单个图像上使用tf.image.decode_image(),而不在批处理中使用。我的理解正确吗?如果是,我如何解码一批图像?

提前谢谢!

CC

代码如下:

Module A

这将引发以下ValueError:

ds = tf.data.TFRecordDataset(TFR_FILENAME)

ds = ds.repeat(EPOCHS)

ds = ds.shuffle(BUFFER_SIZE + BATCH_SIZE)

ds = ds.batch(BATCH_SIZE)

finally I tried to decode the image using tf.image.decode_image()

feature_description = {'height': tf.io.FixedLenFeature([],tf.int64),'width': tf.io.FixedLenFeature([],'depth': tf.io.FixedLenFeature([],'kpts': tf.io.FixedLenFeature([136],tf.float32),'image_raw': tf.io.FixedLenFeature([],tf.string),}

for record in ds.take(1):
    record = tf.io.parse_example(record,feature_description)
    decoded_image = tf.io.decode_image(record['image_raw'],dtype=tf.float32)

cghchong 回答:无法使用tf.io.decode_image解码一批图像

实际上,decode_image仅适用于单个图像。在批处理之前,您仍然应该通过对数据集进行解码来获得合理的性能。

类似这样的东西(未经测试的代码,可能需要一些调整):

ds = tf.data.TFRecordDataset(TFR_FILENAME)

def parse_and_decode(record):
  record = tf.io.parse_example(record,feature_description)
  record['image'] = tf.io.decode_image(record['image_raw'],dtype=tf.float32)
  return record

ds = ds.map(parse_and_decode)

ds = ds.repeat(EPOCHS)

ds = ds.shuffle(BUFFER_SIZE + BATCH_SIZE)
...
本文链接:https://www.f2er.com/2967342.html

大家都在问