Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2989c012
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2989c012
编写于
7月 05, 2020
作者:
Z
Zhen Wang
提交者:
GitHub
7月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[DygraphToStatic]Add cast transform for dygraph_to_static. (#25325)
* add cast transform and its UT for dygraph_to_static.
上级
bdad383c
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
246 addition
and
1 deletion
+246
-1
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+5
-1
python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py
...addle/fluid/dygraph/dygraph_to_static/cast_transformer.py
+47
-0
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+21
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_cast.py
+173
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
2989c012
...
@@ -21,9 +21,10 @@ from __future__ import print_function
...
@@ -21,9 +21,10 @@ from __future__ import print_function
import
gast
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.assert_transformer
import
AssertTransformer
from
paddle.fluid.dygraph.dygraph_to_static.assert_transformer
import
AssertTransformer
from
paddle.fluid.dygraph.dygraph_to_static.call_transformer
import
CallTransformer
from
paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer
import
BasicApiTransformer
from
paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer
import
BasicApiTransformer
from
paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer
import
BreakContinueTransformer
from
paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer
import
BreakContinueTransformer
from
paddle.fluid.dygraph.dygraph_to_static.call_transformer
import
CallTransformer
from
paddle.fluid.dygraph.dygraph_to_static.cast_transformer
import
CastTransformer
from
paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer
import
IfElseTransformer
from
paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer
import
IfElseTransformer
from
paddle.fluid.dygraph.dygraph_to_static.list_transformer
import
ListTransformer
from
paddle.fluid.dygraph.dygraph_to_static.list_transformer
import
ListTransformer
from
paddle.fluid.dygraph.dygraph_to_static.logical_transformer
import
LogicalTransformer
from
paddle.fluid.dygraph.dygraph_to_static.logical_transformer
import
LogicalTransformer
...
@@ -93,6 +94,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -93,6 +94,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform call recursively
# Transform call recursively
CallTransformer
(
node_wrapper
).
transform
()
CallTransformer
(
node_wrapper
).
transform
()
# Transform python type casting statement
CastTransformer
(
node_wrapper
).
transform
()
def
visit_FunctionDef
(
self
,
node
):
def
visit_FunctionDef
(
self
,
node
):
if
self
.
decorate_func_name
is
None
:
if
self
.
decorate_func_name
is
None
:
self
.
decorate_func_name
=
node
.
name
self
.
decorate_func_name
=
node
.
name
...
...
python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py
0 → 100644
浏览文件 @
2989c012
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
class
CastTransformer
(
gast
.
NodeTransformer
):
"""
This class transforms type casting into Static Graph Ast.
"""
def
__init__
(
self
,
wrapper_root
):
assert
isinstance
(
wrapper_root
,
AstNodeWrapper
),
"Input non-AstNodeWrapper node for the initialization of CastTransformer."
self
.
_root
=
wrapper_root
.
node
self
.
_castable_type
=
{
'bool'
,
'int'
,
'float'
}
def
transform
(
self
):
self
.
visit
(
self
.
_root
)
def
visit_Call
(
self
,
node
):
self
.
generic_visit
(
node
)
func_str
=
ast_to_source_code
(
node
.
func
).
strip
()
if
func_str
in
self
.
_castable_type
and
len
(
node
.
args
)
>
0
:
args_str
=
ast_to_source_code
(
node
.
args
[
0
]).
strip
()
new_func_str
=
"fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')"
.
format
(
args_str
,
func_str
)
new_node
=
gast
.
parse
(
new_func_str
).
body
[
0
].
value
return
new_node
return
node
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
2989c012
...
@@ -238,3 +238,24 @@ def cast_bool_if_necessary(var):
...
@@ -238,3 +238,24 @@ def cast_bool_if_necessary(var):
if
convert_dtype
(
var
.
dtype
)
not
in
[
'bool'
]:
if
convert_dtype
(
var
.
dtype
)
not
in
[
'bool'
]:
var
=
cast
(
var
,
dtype
=
"bool"
)
var
=
cast
(
var
,
dtype
=
"bool"
)
return
var
return
var
def
convert_var_dtype
(
var
,
dtype
):
if
isinstance
(
var
,
Variable
):
src_dtype
=
convert_dtype
(
var
.
dtype
)
assert
src_dtype
in
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
],
"The dtype of var {} is {}, which is not supported in the cast op."
.
format
(
var
.
name
,
src_dtype
)
assert
dtype
in
[
'bool'
,
'int'
,
'float'
],
"The casted target dtype is {}, which is not supported in type casting."
.
format
(
dtype
)
cast_map
=
{
'bool'
:
'bool'
,
'int'
:
'int32'
,
'float'
:
'float32'
,
}
return
cast
(
var
,
dtype
=
cast_map
[
dtype
])
else
:
return
eval
(
'{}(var)'
.
format
(
dtype
))
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cast.py
0 → 100644
浏览文件 @
2989c012
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
declarative
SEED
=
2020
np
.
random
.
seed
(
SEED
)
@
declarative
def
test_bool_cast
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
bool
(
x
)
return
x
@
declarative
def
test_int_cast
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
int
(
x
)
return
x
@
declarative
def
test_float_cast
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
float
(
x
)
return
x
@
declarative
def
test_not_var_cast
(
x
):
x
=
int
(
x
)
return
x
@
declarative
def
test_mix_cast
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
int
(
x
)
x
=
float
(
x
)
x
=
bool
(
x
)
x
=
float
(
x
)
return
x
class
TestCastBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
self
.
prepare
()
self
.
set_func
()
def
prepare
(
self
):
self
.
input_shape
=
(
16
,
32
)
self
.
input_dtype
=
'float32'
self
.
input
=
np
.
random
.
binomial
(
4
,
0.3
,
size
=
np
.
product
(
self
.
input_shape
)).
reshape
(
self
.
input_shape
).
astype
(
self
.
input_dtype
)
self
.
cast_dtype
=
'bool'
def
set_func
(
self
):
self
.
func
=
test_bool_cast
def
do_test
(
self
):
with
fluid
.
dygraph
.
guard
():
res
=
self
.
func
(
self
.
input
)
return
res
def
test_cast_result
(
self
):
res
=
self
.
do_test
().
numpy
()
self
.
assertTrue
(
res
.
dtype
==
self
.
cast_dtype
,
msg
=
'The target dtype is {}, but the casted dtype is {}.'
.
format
(
self
.
cast_dtype
,
res
.
dtype
))
ref_val
=
self
.
input
.
astype
(
self
.
cast_dtype
)
self
.
assertTrue
(
np
.
allclose
(
res
,
ref_val
),
msg
=
'The casted value is {}.
\n
The correct value is {}.'
.
format
(
res
,
ref_val
))
class
TestIntCast
(
TestCastBase
):
def
prepare
(
self
):
self
.
input_shape
=
(
1
,
)
self
.
input_dtype
=
'float32'
self
.
input
=
np
.
random
.
normal
(
loc
=
6
,
scale
=
10
,
size
=
np
.
product
(
self
.
input_shape
)).
reshape
(
self
.
input_shape
).
astype
(
self
.
input_dtype
)
self
.
cast_dtype
=
'int32'
def
set_func
(
self
):
self
.
func
=
test_int_cast
class
TestFloatCast
(
TestCastBase
):
def
prepare
(
self
):
self
.
input_shape
=
(
8
,
16
)
self
.
input_dtype
=
'bool'
self
.
input
=
np
.
random
.
binomial
(
2
,
0.5
,
size
=
np
.
product
(
self
.
input_shape
)).
reshape
(
self
.
input_shape
).
astype
(
self
.
input_dtype
)
self
.
cast_dtype
=
'float32'
def
set_func
(
self
):
self
.
func
=
test_float_cast
class
TestMixCast
(
TestCastBase
):
def
prepare
(
self
):
self
.
input_shape
=
(
8
,
32
)
self
.
input_dtype
=
'float32'
self
.
input
=
np
.
random
.
normal
(
loc
=
6
,
scale
=
10
,
size
=
np
.
product
(
self
.
input_shape
)).
reshape
(
self
.
input_shape
).
astype
(
self
.
input_dtype
)
self
.
cast_int
=
'int'
self
.
cast_float
=
'float32'
self
.
cast_bool
=
'bool'
self
.
cast_dtype
=
'float32'
def
set_func
(
self
):
self
.
func
=
test_mix_cast
def
test_cast_result
(
self
):
res
=
self
.
do_test
().
numpy
()
self
.
assertTrue
(
res
.
dtype
==
self
.
cast_dtype
,
msg
=
'The target dtype is {}, but the casted dtype is {}.'
.
format
(
self
.
cast_dtype
,
res
.
dtype
))
ref_val
=
self
.
input
.
astype
(
self
.
cast_int
).
astype
(
self
.
cast_float
).
astype
(
self
.
cast_bool
).
astype
(
self
.
cast_dtype
)
self
.
assertTrue
(
np
.
allclose
(
res
,
ref_val
),
msg
=
'The casted value is {}.
\n
The correct value is {}.'
.
format
(
res
,
ref_val
))
class
TestNotVarCast
(
TestCastBase
):
def
prepare
(
self
):
self
.
input
=
3.14
self
.
cast_dtype
=
'int'
def
set_func
(
self
):
self
.
func
=
test_not_var_cast
def
test_cast_result
(
self
):
res
=
self
.
do_test
()
self
.
assertTrue
(
type
(
res
)
==
int
,
msg
=
'The casted dtype is not int.'
)
ref_val
=
int
(
self
.
input
)
self
.
assertTrue
(
res
==
ref_val
,
msg
=
'The casted value is {}.
\n
The correct value is {}.'
.
format
(
res
,
ref_val
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录