Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b3520b14
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b3520b14
编写于
4月 17, 2020
作者:
L
liym27
提交者:
GitHub
4月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Get answer code from function instead of str. test=develop (#23904)
上级
37ef7c13
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
53 addition
and
36 deletion
+53
-36
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
...ts/unittests/dygraph_to_static/test_program_translator.py
+53
-36
未找到文件。
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
浏览文件 @
b3520b14
...
...
@@ -14,62 +14,79 @@
from
__future__
import
print_function
import
astor
import
gast
import
inspect
import
textwrap
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static
import
ProgramTranslator
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_code
from
ifelse_simple_func
import
dyfunc_with_if_else
def
get_source_code
(
func
):
raw_code
=
inspect
.
getsource
(
func
)
code
=
textwrap
.
dedent
(
raw_code
)
root
=
gast
.
parse
(
code
)
source_code
=
astor
.
to_source
(
gast
.
gast_to_ast
(
root
))
return
source_code
class
StaticCode1
():
def
dyfunc_with_if_else
(
x_v
,
label
=
None
):
def
true_fn_0
(
x_v
):
x_v
=
x_v
-
1
return
x_v
def
false_fn_0
(
x_v
):
x_v
=
x_v
+
1
return
x_v
x_v
=
fluid
.
layers
.
cond
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
lambda
:
true_fn_0
(
x_v
),
lambda
:
false_fn_0
(
x_v
))
if
label
is
not
None
:
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
return
loss
return
x_v
class
StaticCode2
():
def
dyfunc_with_if_else
(
x_v
,
label
=
None
):
def
true_fn_1
(
x_v
):
x_v
=
x_v
-
1
return
x_v
def
false_fn_1
(
x_v
):
x_v
=
x_v
+
1
return
x_v
x_v
=
fluid
.
layers
.
cond
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
lambda
:
true_fn_1
(
x_v
),
lambda
:
false_fn_1
(
x_v
))
if
label
is
not
None
:
loss
=
fluid
.
layers
.
cross_entropy
(
x_v
,
label
)
return
loss
return
x_v
class
TestDygraphToStaticCode
(
unittest
.
TestCase
):
def
setUp
(
self
):
# set to print all string diff when assertEqual fails
self
.
maxDiff
=
None
def
test_decorator
(
self
):
answer
=
"
\
def dyfunc_with_if_else(x_v, label=None):
\n\
\n\
def true_fn_0(x_v):
\n\
x_v = x_v - 1
\n\
return x_v
\n\
\n\
def false_fn_0(x_v):
\n\
x_v = x_v + 1
\n\
return x_v
\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :
\n\
true_fn_0(x_v), lambda : false_fn_0(x_v))
\n\
if label is not None:
\n\
loss = fluid.layers.cross_entropy(x_v, label)
\n\
return loss
\n\
return x_v
\n
"
x_v
=
None
answer
=
get_source_code
(
StaticCode1
.
dyfunc_with_if_else
)
code
=
dygraph_to_static_code
(
dyfunc_with_if_else
)(
x_v
)
self
.
assertEqual
(
answer
,
code
)
def
test_program_translator
(
self
):
answer
=
"
\
def dyfunc_with_if_else(x_v, label=None):
\n\
\n\
def true_fn_1(x_v):
\n\
x_v = x_v - 1
\n\
return x_v
\n\
\n\
def false_fn_1(x_v):
\n\
x_v = x_v + 1
\n\
return x_v
\n\
x_v = fluid.layers.cond(fluid.layers.mean(x_v)[0] > 5, lambda :
\n\
true_fn_1(x_v), lambda : false_fn_1(x_v))
\n\
if label is not None:
\n\
loss = fluid.layers.cross_entropy(x_v, label)
\n\
return loss
\n\
return x_v
\n
"
answer
=
get_source_code
(
StaticCode2
.
dyfunc_with_if_else
)
program_translator
=
ProgramTranslator
()
code
=
program_translator
.
get_code
(
dyfunc_with_if_else
)
self
.
assertEqual
(
answer
,
code
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录