From b3520b14fcc6d351f85b3d0f29fd74160253d0db Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 17 Apr 2020 16:38:56 +0800 Subject: [PATCH] Get answer code from function instead of str. test=develop (#23904) --- .../test_program_translator.py | 89 +++++++++++-------- 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index f96f1ceafa6..44ffaf2cafa 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -14,62 +14,79 @@ from __future__ import print_function +import astor +import gast +import inspect +import textwrap import unittest -import numpy as np import paddle.fluid as fluid - from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.jit import dygraph_to_static_code from ifelse_simple_func import dyfunc_with_if_else +def get_source_code(func): + raw_code = inspect.getsource(func) + code = textwrap.dedent(raw_code) + root = gast.parse(code) + source_code = astor.to_source(gast.gast_to_ast(root)) + return source_code + + +class StaticCode1(): + def dyfunc_with_if_else(x_v, label=None): + def true_fn_0(x_v): + x_v = x_v - 1 + return x_v + + def false_fn_0(x_v): + x_v = x_v + 1 + return x_v + + x_v = fluid.layers.cond( + fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_0(x_v), + lambda: false_fn_0(x_v)) + if label is not None: + loss = fluid.layers.cross_entropy(x_v, label) + return loss + return x_v + + +class StaticCode2(): + def dyfunc_with_if_else(x_v, label=None): + def true_fn_1(x_v): + x_v = x_v - 1 + return x_v + + def false_fn_1(x_v): + x_v = x_v + 1 + return x_v + + x_v = fluid.layers.cond( + fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_1(x_v), + lambda: false_fn_1(x_v)) + + if label is not None: + loss = fluid.layers.cross_entropy(x_v, label) + return loss + return x_v + + class TestDygraphToStaticCode(unittest.TestCase): def setUp(self): # set to print all string diff when assertEqual fails self.maxDiff = None def test_decorator(self): - answer = "\ -def dyfunc_with_if_else(x_v, label=None):\n\ -\n\ - def true_fn_0(x_v):\n\ - x_v = x_v - 1\n\ - return x_v\n\ -\n\ - def false_fn_0(x_v):\n\ - x_v = x_v + 1\n\ - return x_v\n\ - x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\ - true_fn_0(x_v), lambda : false_fn_0(x_v))\n\ - if label is not None:\n\ - loss = fluid.layers.cross_entropy(x_v, label)\n\ - return loss\n\ - return x_v\n" - x_v = None + answer = get_source_code(StaticCode1.dyfunc_with_if_else) code = dygraph_to_static_code(dyfunc_with_if_else)(x_v) self.assertEqual(answer, code) def test_program_translator(self): - answer = "\ -def dyfunc_with_if_else(x_v, label=None):\n\ -\n\ - def true_fn_1(x_v):\n\ - x_v = x_v - 1\n\ - return x_v\n\ -\n\ - def false_fn_1(x_v):\n\ - x_v = x_v + 1\n\ - return x_v\n\ - x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\ - true_fn_1(x_v), lambda : false_fn_1(x_v))\n\ - if label is not None:\n\ - loss = fluid.layers.cross_entropy(x_v, label)\n\ - return loss\n\ - return x_v\n" - + answer = get_source_code(StaticCode2.dyfunc_with_if_else) program_translator = ProgramTranslator() code = program_translator.get_code(dyfunc_with_if_else) self.assertEqual(answer, code) -- GitLab