提交 c31a3d0c 编写于 作者: K Katherine Wu 提交者: TensorFlower Gardener

Copy tensorflow.python.saved_model.loader_impl.get_train_op to test file

#KERAS_PRIVATE_API_CLEANUP

PiperOrigin-RevId: 339917969
Change-Id: Iadd13a8d23a941528e0384e090dc07ddc9d63da6
上级 d2664949
......@@ -278,6 +278,13 @@ def load_model(sess, path, mode):
return inputs, outputs, meta_graph_def
def get_train_op(meta_graph_def):
graph = ops.get_default_graph()
signature_def = meta_graph_def.signature_def['__saved_model_train_op']
op_name = signature_def.outputs['__saved_model_train_op'].name
return graph.as_graph_element(op_name)
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
......@@ -402,7 +409,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
self.assertIn('predictions/' + output_name, outputs)
# Train for a step
train_op = loader_impl.get_train_op(meta_graph_def)
train_op = get_train_op(meta_graph_def)
train_outputs, _ = sess.run(
[outputs, train_op], {inputs[input_name]: input_arr,
inputs[target_name]: target_arr})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册