Unverified · Commit 40162227 · Author: Aurelius84 · Committer: GitHub

Support save model in dygraph.guard test=develop (#24761)

Parent 355caee1
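
Summary of the change: ProgramTranslator.save_inference_model previously had to be called after leaving dygraph mode; this commit decorates it with @switch_to_static_graph so it can be called directly inside fluid.dygraph.guard(), and the tests are updated accordingly. A minimal usage sketch of the new behavior (the layer definition and import paths below are illustrative assumptions, not part of this diff):

```python
import numpy as np
import paddle.fluid as fluid
# Assumed import locations for this era of Paddle (1.8.x):
from paddle.fluid.dygraph import declarative, ProgramTranslator


class SimpleFcLayer(fluid.dygraph.Layer):
    """Illustrative stand-in for the test's SimpleFcLayer."""

    def __init__(self, fc_size):
        super(SimpleFcLayer, self).__init__()
        self._linear = fluid.dygraph.Linear(fc_size, fc_size)

    @declarative
    def forward(self, x):
        y = self._linear(x)
        out = fluid.layers.mean(y)
        return out, y  # output index 0 -> out, index 1 -> y


program_translator = ProgramTranslator()

with fluid.dygraph.guard():
    layer = SimpleFcLayer(8)
    x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32'))
    out, _ = layer(x)
    # This call now works inside dygraph.guard because
    # save_inference_model switches to static-graph mode internally.
    # feed=[0] keeps input index 0 (x); fetch=[1] fetches output index 1 (y).
    program_translator.save_inference_model(
        "./test_dy2stat_save_inference_model", feed=[0], fetch=[1])
```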
......@@ -550,6 +550,7 @@ class ProgramTranslator(object):
         source_code = ast_to_source_code(root_wrapper.node)
         return source_code
 
+    @switch_to_static_graph
     def save_inference_model(self, dirname, feed=None, fetch=None):
         """
         Saves current model as the inference model. It will prune the main_program
......
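
The functional change in this file is just the @switch_to_static_graph decorator on save_inference_model. A simplified sketch of what such a decorator plausibly does (an assumption about its behavior; the real one lives in paddle.fluid.framework and may differ):

```python
import functools
from paddle.fluid import framework


def switch_to_static_graph(func):
    """Sketch: run `func` with the dygraph tracer detached so that any
    ops it builds go into the static Program, then restore dygraph mode."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Assumption: framework._dygraph_guard(None) temporarily clears
        # the active dygraph tracer for the duration of the call.
        with framework._dygraph_guard(None):
            return func(*args, **kwargs)

    return wrapper
```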
......@@ -200,7 +200,6 @@ class TestMNISTWithDeclarative(TestMNIST):
                     break
         return loss_data
 
-    @switch_to_static_graph
     def check_save_inference_model(self, inputs, prog_trans, to_static, gt_out):
         if to_static:
             infer_model_path = "./test_mnist_inference_model"
......@@ -208,6 +207,7 @@ class TestMNISTWithDeclarative(TestMNIST):
             infer_out = self.load_and_run_inference(infer_model_path, inputs)
             self.assertTrue(np.allclose(gt_out.numpy(), infer_out))
 
+    @switch_to_static_graph
     def load_and_run_inference(self, model_path, inputs):
         exe = fluid.Executor(self.place)
         [inference_program, feed_target_names,
......
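
The hunk above is cut off mid-statement. A plausible completion of load_and_run_inference, following the standard fluid.io.load_inference_model pattern (everything after the truncation is reconstructed, not shown in this diff):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import switch_to_static_graph  # assumed location


@switch_to_static_graph
def load_and_run_inference(self, model_path, inputs):
    # Loading and running the saved program must happen in static-graph
    # mode, which is why the decorator moved onto this method.
    exe = fluid.Executor(self.place)
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(
         dirname=model_path, executor=exe)
    results = exe.run(inference_program,
                      feed=dict(zip(feed_target_names, inputs)),
                      fetch_list=fetch_targets)
    return np.array(results[0])
```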
......@@ -30,6 +30,7 @@ np.random.seed(SEED)
 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
 )
 
+program_translator = ProgramTranslator()
 
 class SimpleFcLayer(fluid.dygraph.Layer):
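
Hoisting program_translator to module level is safe because ProgramTranslator is, to my understanding, a process-wide singleton; constructing it again elsewhere (e.g. the local instance removed in the last hunk below) yields the same object:

```python
# Assumed singleton semantics of ProgramTranslator: repeated
# construction returns one shared instance, so module-level and
# local instances see the same traced-program state.
t1 = ProgramTranslator()
t2 = ProgramTranslator()
assert t1 is t2
```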
......@@ -63,6 +64,10 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
                 loss.backward()
                 adam.minimize(loss)
                 layer.clear_gradients()
+            # test for saving model in dygraph.guard
+            infer_model_dir = "./test_dy2stat_save_inference_model"
+            program_translator.save_inference_model(
+                infer_model_dir, feed=[0], fetch=[1])
             # Check the correctness of the inference
             dygraph_out, _ = layer(x)
             self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
......@@ -77,7 +82,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
                                     gt_out,
                                     feed=None,
                                     fetch=None):
-        program_translator = ProgramTranslator()
         expected_persistable_vars = set([p.name for p in model.parameters()])
 
         infer_model_dir = "./test_dy2stat_save_inference_model"
......
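
The expected_persistable_vars set suggests the truncated remainder of check_save_inference_model verifies that every model parameter was written to disk. A hypothetical version of that check (the actual assertions are not visible in this diff):

```python
import os


def assert_persistables_saved(model, infer_model_dir):
    # Hypothetical check mirroring the test's apparent intent: each
    # dygraph parameter is saved as a separate file next to the
    # '__model__' program description.
    expected_persistable_vars = set(p.name for p in model.parameters())
    saved_var_names = set(
        name for name in os.listdir(infer_model_dir) if name != '__model__')
    assert expected_persistable_vars == saved_var_names
```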