我一直在尝试使用Tensorflow数据集,但无法弄清楚如何有效地创建RLE蒙版。 仅供参考,我正在使用Kaggle的空客船舶检测挑战赛中的dat:https://www.kaggle.com/c/airbus-ship-detection/data
我知道我的RLE解码功能可以从以下一种内核中工作(借用):
def rle_decode(mask_rle,shape=(768,768)):
'''
mask_rle: run-length as string formated (start length)
shape: (height,width) of array to return
Returns numpy array,1 - mask,0 - background
'''
if not isinstance(mask_rle,str):
img = np.zeros(shape[0]*shape[1],dtype=np.uint8)
return img.reshape(shape).T
s = mask_rle.split()
starts,lengths = [np.asarray(x,dtype=int) for x in (s[0:][::2],s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(shape[0]*shape[1],dtype=np.uint8)
for lo,hi in zip(starts,ends):
img[lo:hi] = 1
return img.reshape(shape).T
....但是它似乎不能很好地与管道配合使用:
list_ds = tf.data.Dataset.list_files(train_paths_abs)
ds = list_ds.map(parse_img)
使用以下解析功能,一切正常:
def parse_img(file_path,new_size=[128,128]):
img_content = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img_content)
img = tf.image.convert_image_dtype(img,tf.float32)
img = tf.image.resize(img,new_size)
return img
但是如果我戴上口罩,事情就会变得很糟糕
def parse_img(file_path,128]):
# Image
img_content = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img_content)
img = tf.image.convert_image_dtype(img,new_size)
# Mask
file_id = tf.strings.split(file_path,'/')[-1]
objects = [rle_decode(m) for m in df2[df.ImageId==file_id]]
mask = np.sum(objects,axis=0)
mask = np.expand_dims(mask,3) # Force mask to have 3 channels,necessary for resize step
mask = tf.image.convert_image_dtype(mask,tf.int8)
mask = tf.clip_by_value(mask,1)
mask = tf.image.resize(mask,new_size)
mask = tf.squeeze(mask) # squeeze back
mask = tf.image.convert_image_dtype(mask,tf.int8)
return img,mask
尽管我的parse_img
函数工作正常(我已经在一个样本上对其进行了检查,但是每次运行需要271 µs±67.9 µs); list_ds.map
步骤将永远挂起(> 5分钟),然后再挂起。
我不知道怎么了,这让我发疯!
有什么想法吗?