未验证 提交 84f899cb 编写于 作者: L liym27 提交者: GitHub

[cherry-pick] Get answer code from function instead of str. test=develop (#23904) (#23949)

上级 8ab4890a
......@@ -14,62 +14,79 @@
from __future__ import print_function
import astor
import gast
import inspect
import textwrap
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_code
from ifelse_simple_func import dyfunc_with_if_else
def get_source_code(func):
raw_code = inspect.getsource(func)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
source_code = astor.to_source(gast.gast_to_ast(root))
return source_code
class StaticCode1():
def dyfunc_with_if_else(x_v, label=None):
def true_fn_0(x_v):
x_v = x_v - 1
return x_v
def false_fn_0(x_v):
x_v = x_v + 1
return x_v
x_v = fluid.layers.cond(
fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_0(x_v),
lambda: false_fn_0(x_v))
if label is not None:
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return x_v
class StaticCode2():
def dyfunc_with_if_else(x_v, label=None):
def true_fn_1(x_v):
x_v = x_v - 1
return x_v
def false_fn_1(x_v):
x_v = x_v + 1
return x_v
x_v = fluid.layers.cond(
fluid.layers.mean(x_v)[0] > 5, lambda: true_fn_1(x_v),
lambda: false_fn_1(x_v))
if label is not None:
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return x_v
class TestDygraphToStaticCode(unittest.TestCase):
def setUp(self):
# set to print all string diff when assertEqual fails
self.maxDiff = None
def test_decorator(self):
answer = "\
def dyfunc_with_if_else(x_v, label=None):\n\
\n\
def true_fn_0(x_v):\n\
x_v = x_v - 1\n\
return x_v\n\
\n\
def false_fn_0(x_v):\n\
x_v = x_v + 1\n\
return x_v\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\
true_fn_0(x_v), lambda : false_fn_0(x_v))\n\
if label is not None:\n\
loss = fluid.layers.cross_entropy(x_v, label)\n\
return loss\n\
return x_v\n"
x_v = None
answer = get_source_code(StaticCode1.dyfunc_with_if_else)
code = dygraph_to_static_code(dyfunc_with_if_else)(x_v)
self.assertEqual(answer, code)
def test_program_translator(self):
answer = "\
def dyfunc_with_if_else(x_v, label=None):\n\
\n\
def true_fn_1(x_v):\n\
x_v = x_v - 1\n\
return x_v\n\
\n\
def false_fn_1(x_v):\n\
x_v = x_v + 1\n\
return x_v\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :\n\
true_fn_1(x_v), lambda : false_fn_1(x_v))\n\
if label is not None:\n\
loss = fluid.layers.cross_entropy(x_v, label)\n\
return loss\n\
return x_v\n"
answer = get_source_code(StaticCode2.dyfunc_with_if_else)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
self.assertEqual(answer, code)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册