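I've been trying to work with TensorFlow Datasets, but I can't figure out how to efficiently create RLE masks. FYI, I'm using the data from Kaggle's Airbus Ship Detection Challenge: https://www.kaggle.com/c/airbus-ship-detection/data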
I know my RLE decoding function works (it is borrowed from one of the kernels):
    import numpy as np

    def rle_decode(mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formatted (start length)
        shape: (height, width) of array to return
        Returns numpy array, 1 - mask, 0 - background
        '''
        # Non-string input (e.g. NaN) yields an empty mask
        if not isinstance(mask_rle, str):
            img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
            return img.reshape(shape).T

        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1  # RLE starts are 1-indexed
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T
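To show what the decoder does in isolation, here is a minimal sketch on a toy mask. The 3x3 shape and the RLE string `"1 2 6 1"` are made-up values for illustration, not taken from the competition data. It suggests that run starts are 1-indexed and that pixels are numbered down the columns first:

```python
import numpy as np

def rle_decode(mask_rle, shape=(768, 768)):
    # Same logic as the borrowed decoder, repeated so this snippet runs on its own
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    if not isinstance(mask_rle, str):
        return img.reshape(shape).T  # non-string input -> empty mask
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # convert to 0-indexed
    lengths = np.asarray(s[1::2], dtype=int)
    for lo, hi in zip(starts, starts + lengths):
        img[lo:hi] = 1
    return img.reshape(shape).T

# Hypothetical toy input: a run of 2 pixels starting at pixel 1,
# and a run of 1 pixel starting at pixel 6
toy = rle_decode("1 2 6 1", shape=(3, 3))
print(toy)
# [[1 0 0]
#  [1 0 0]
#  [0 1 0]]
```

Passing a non-string such as NaN (which is how I am assuming images without ships appear in the CSV) returns an all-zero mask of the same shape.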
...but it doesn't seem to play well inside the pipeline:
    list_ds = tf.data.Dataset.list_files(train_paths_abs)
    ds = list_ds.map(parse_img)
With the following parse function, everything works fine:
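    import tensorflow as tf

    def parse_img(file_path, new_size=(128, 128)):
        # Read the raw JPEG bytes and decode them into a uint8 tensor
        img_content = tf.io.read_file(file_path)
        img = tf.image.decode_jpeg(img_content)
        # Scale to float32 in [0, 1], then resize
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, new_size)
        return img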
But if I include the mask, things go wrong:
    def rle_decode(mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formatted (start length)
        shape: (height, width) of array to return
        Returns numpy array, 1 - mask, 0 - background
        '''
        if not isinstance(mask_rle, str):
            img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
            return img.reshape(shape).T

        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T
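Although my parse_img function works fine on its own (I checked it on a single sample, and it takes 271 µs ± 67.9 µs per run), the list_ds.map step takes forever (> 5 minutes) before hanging. I can't figure out what's wrong, and it's driving me crazy!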
Any ideas?