我对Tensorflow估算器非常陌生。我正在尝试在训练后保存 pix2pix 模型。使用Estimators以后如何在本地计算机上导出和使用模型基本上是我的问题(就像我们对tf.keras模型所做的那样,即加载权重并进行预测)。我已经尝试了几乎所有看到的解决方案。
对于 serving_input_fn 和 tf.estimator.ModeKeys.PREDICT 案例,我大多感到困惑我写的> model_fn :
^[[3J^[[H^[[2J
serving_input_fn stackoverflow source。
我正在使用 TFRECORD 进行训练,但希望使用 .png 图片进行预测。我的 model_fn 是:
def serving_input_receiver_fn(flaGS):
def decode_and_resize(image_str_tensor):
image = tf.image.decode_jpeg(image_str_tensor,channels=flaGS.NB_CHANNELS)
image = tf.expand_dims(image,0)
image = tf.image.resize_bilinear(image,[flaGS.IMAGE_DIM,flaGS.IMAGE_DIM],align_corners=False)
image = tf.squeeze(image,squeeze_dims=[0])
image = tf.cast(image,dtype=tf.uint8)
return image
input_ph = tf.placeholder(tf.string,shape=[None],name='image_binary')
images_tensor = tf.map_fn(decode_and_resize,input_ph,back_prop=False,dtype=tf.uint8)
images_tensor = tf.image.convert_image_dtype(images_tensor,dtype=tf.float32)
return tf.estimator.export.ServingInputReceiver({'images': images_tensor},{'bytes': input_ph}
)
serving_input_fn = partial(serving_input_receiver_fn,flaGS=flaGS)
估算器训练得很好,检查点和其他文件被顺利保存在 GCS-Bucket 中。只有当我尝试导出模型时,它才会给出错误:
def model_fn(features,labels,mode,params):
if (mode != tf.estimator.ModeKeys.PREDICT):
loss = loss_fn(features,labels)
learning_rate = tf.train.exponential_decay(flaGS.LEARNING_RATE,tf.train.get_global_step(),decay_steps=100000,decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op=optimizer.minimize(loss,tf.train.get_global_step())
return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,loss=loss,train_op=train_op)
else:
predictions=generator_fn(features)
return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,predictions={"predictions": predictions})