未验证 提交 40162227 编写于 作者: A Aurelius84 提交者: GitHub

Support save model in dygraph.guard test=develop (#24761)

上级 355caee1
...@@ -550,6 +550,7 @@ class ProgramTranslator(object): ...@@ -550,6 +550,7 @@ class ProgramTranslator(object):
source_code = ast_to_source_code(root_wrapper.node) source_code = ast_to_source_code(root_wrapper.node)
return source_code return source_code
@switch_to_static_graph
def save_inference_model(self, dirname, feed=None, fetch=None): def save_inference_model(self, dirname, feed=None, fetch=None):
""" """
Saves current model as the inference model. It will prune the main_program Saves current model as the inference model. It will prune the main_program
......
...@@ -200,7 +200,6 @@ class TestMNISTWithDeclarative(TestMNIST): ...@@ -200,7 +200,6 @@ class TestMNISTWithDeclarative(TestMNIST):
break break
return loss_data return loss_data
@switch_to_static_graph
def check_save_inference_model(self, inputs, prog_trans, to_static, gt_out): def check_save_inference_model(self, inputs, prog_trans, to_static, gt_out):
if to_static: if to_static:
infer_model_path = "./test_mnist_inference_model" infer_model_path = "./test_mnist_inference_model"
...@@ -208,6 +207,7 @@ class TestMNISTWithDeclarative(TestMNIST): ...@@ -208,6 +207,7 @@ class TestMNISTWithDeclarative(TestMNIST):
infer_out = self.load_and_run_inference(infer_model_path, inputs) infer_out = self.load_and_run_inference(infer_model_path, inputs)
self.assertTrue(np.allclose(gt_out.numpy(), infer_out)) self.assertTrue(np.allclose(gt_out.numpy(), infer_out))
@switch_to_static_graph
def load_and_run_inference(self, model_path, inputs): def load_and_run_inference(self, model_path, inputs):
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
[inference_program, feed_target_names, [inference_program, feed_target_names,
......
...@@ -30,6 +30,7 @@ np.random.seed(SEED) ...@@ -30,6 +30,7 @@ np.random.seed(SEED)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
) )
program_translator = ProgramTranslator()
class SimpleFcLayer(fluid.dygraph.Layer): class SimpleFcLayer(fluid.dygraph.Layer):
...@@ -63,6 +64,10 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase): ...@@ -63,6 +64,10 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
loss.backward() loss.backward()
adam.minimize(loss) adam.minimize(loss)
layer.clear_gradients() layer.clear_gradients()
# test for saving model in dygraph.guard
infer_model_dir = "./test_dy2stat_save_inference_model"
program_translator.save_inference_model(
infer_model_dir, feed=[0], fetch=[1])
# Check the correctness of the inference # Check the correctness of the inference
dygraph_out, _ = layer(x) dygraph_out, _ = layer(x)
self.check_save_inference_model(layer, [x_data], dygraph_out.numpy()) self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
...@@ -77,7 +82,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase): ...@@ -77,7 +82,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
gt_out, gt_out,
feed=None, feed=None,
fetch=None): fetch=None):
program_translator = ProgramTranslator()
expected_persistable_vars = set([p.name for p in model.parameters()]) expected_persistable_vars = set([p.name for p in model.parameters()])
infer_model_dir = "./test_dy2stat_save_inference_model" infer_model_dir = "./test_dy2stat_save_inference_model"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册