我正在尝试进行面部关键点回归。我成功创建了带有图像和标签编码的TFRecord文件(标签是面部kypoint)。
然后,我开始将数据(图像和关键点)加载到内存中(遵循dependency scope处的指南)。我想先批处理所有图像,然后按照该指南中的说明对图像进行解码。但是,这不起作用。如果我的理解是正确的,我只能在单个图像上使用tf.image.decode_image(),而不在批处理中使用。我的理解正确吗?如果是,我如何解码一批图像?
提前谢谢!
CC
代码如下:
Module A
这将引发以下ValueError:
ds = tf.data.TFRecordDataset(TFR_FILENAME)
ds = ds.repeat(EPOCHS)
ds = ds.shuffle(BUFFER_SIZE + BATCH_SIZE)
ds = ds.batch(BATCH_SIZE)
finally I tried to decode the image using tf.image.decode_image()
feature_description = {'height': tf.io.FixedLenFeature([],tf.int64),'width': tf.io.FixedLenFeature([],'depth': tf.io.FixedLenFeature([],'kpts': tf.io.FixedLenFeature([136],tf.float32),'image_raw': tf.io.FixedLenFeature([],tf.string),}
for record in ds.take(1):
record = tf.io.parse_example(record,feature_description)
decoded_image = tf.io.decode_image(record['image_raw'],dtype=tf.float32)