Unverified commit 83baab9b authored by liym27, committed by GitHub

[cherry-pick 2.0-beta][Dy2Stat] Transforme api 'to_tensor' to 'assign'. (#26873) (#27055)

Change-Id: Ic5b211f1bab42067715297fe58a78646e13e048d
Parent 873da75a
@@ -16,9 +16,7 @@ import astor
 import gast
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
-from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
-from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
+from paddle.fluid.dygraph.dygraph_to_static import utils


 class BasicApiTransformer(gast.NodeTransformer):
@@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
             if isinstance(child_node, gast.Call):
                 # TODO(liym27):
                 # Considers that a dygraph api which modifies the input or has a output.
-                if is_dygraph_api(child_node):
+                if utils.is_dygraph_api(child_node):
                     return
                 else:
                     self._visit_Call(child_node)
@@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
         if self._is_dygraph_forward(func_name):
             class_node = self._get_class_node(func_name)
-            static_node = to_static_ast(node, class_node)
+            static_node = utils.to_static_ast(node, class_node)
             return static_node
         else:
             return node
@@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
             if is_to_variable(node_value):
                 return False
-            if is_dygraph_api(node_value):
+            if utils.is_dygraph_api(node_value):
                 dygraph_api = node_value.func.attr
-                if not dygraph_class_to_static_api.get(dygraph_api):
+                if not utils.dygraph_class_to_static_api.get(dygraph_api):
                     return False
-                update_args_of_func(node_value, node_value, "__init__")
+                utils.update_args_of_func(node_value, node_value, "__init__")
                 target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
                 self.class_node_dict[target_str] = node_value
                 return True
             # TODO: node.value is not dygraph class
             return False
+
+
+def is_to_variable(node):
+    assert isinstance(node, gast.Call)
+    api_name = utils.ast_to_source_code(node.func).strip()
+
+    if utils.is_dygraph_api(node):
+        return api_name.endswith("to_variable")
+
+    if utils.is_paddle_api(node):
+        return api_name.endswith("to_tensor")
+
+    return False
+
+
+def to_assign_node(node):
+    # Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
+    # NOTE:
+    #   1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
+    #   but api `assign` only supports {float32, float64, int32, int64, bool};
+    #   2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
+    assert isinstance(node, gast.Call)
+    assign_api = gast.parse('fluid.layers.assign').body[0].value
+    node.func = assign_api
+
+    if node.args:
+        node.args = [node.args[0]]
+        node.keywords = []
+    else:
+        for idx, kw in enumerate(node.keywords):
+            if kw.arg == 'value' or kw.arg == 'data':
+                node.keywords[idx].arg = 'input'
+                node.keywords = [node.keywords[idx]]
+                node.args = []
+                break
+    return node
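For readers less familiar with the gast/astor toolchain, the snippet below is a self-contained sketch (not part of the commit) that reproduces the effect of the `to_assign_node` rewrite on a single statement; it assumes `gast` and `astor` are installed, and the sample statement and variable names are illustrative only.

# Standalone sketch of the rewrite performed by `to_assign_node` above,
# bypassing the BasicApiTransformer plumbing. Illustrative only.
import astor
import gast

code = "y = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)"
tree = gast.parse(code)
call_node = tree.body[0].value  # the gast.Call node for paddle.to_tensor(...)

# Retarget the call at the static api and keep only the first positional
# argument, mirroring the positional branch of `to_assign_node`.
call_node.func = gast.parse('fluid.layers.assign').body[0].value
call_node.args = [call_node.args[0]]
call_node.keywords = []

print(astor.to_source(gast.gast_to_ast(tree)).strip())
# Prints: y = fluid.layers.assign(x)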
@@ -136,9 +136,12 @@ def is_api_in_module(node, module_prefix):
         # import_str = "".join(import_statements)
         import paddle
         import paddle.fluid as fluid
+        import paddle.fluid.dygraph as dygraph
         import paddle.fluid.layers as layers
+
         from paddle.fluid.dygraph import to_variable
-        import paddle.fluid.dygraph as dygraph
+        from paddle import to_tensor
+
         return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
                                                                 module_prefix))
     except NameError:
@@ -146,15 +149,18 @@ def is_api_in_module(node, module_prefix):
 def is_dygraph_api(node):
     # Note: A api in module dygraph_to_static is not a real dygraph api.
     if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
         return False
+
+    # TODO(liym27): A better way to determine whether it is a dygraph api.
+    # Consider the decorator @dygraph_only
     return is_api_in_module(node, "paddle.fluid.dygraph")


 def is_paddle_api(node):
-    return is_api_in_module(node, "paddle.fluid")
+    return is_api_in_module(node, "paddle")


 # Is numpy_api cannot reuse is_api_in_module because of numpy module problem
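Widening the prefix in `is_paddle_api` from "paddle.fluid" to "paddle" is what lets the new `is_to_variable` recognize `paddle.to_tensor`, which is not under the `paddle.fluid` subpackage. The check ultimately resolves the called object and compares the module it was defined in against the prefix. The snippet below is a hedged, self-contained illustration of that idea: `has_module_prefix` is an invented name, not the Paddle helper, and numpy stands in for paddle so it runs anywhere.

# Illustration only: a simplified stand-in for the module-prefix test behind
# is_api_in_module / _is_api_in_module_helper.
import numpy

def has_module_prefix(fn, prefix):
    # A callable "belongs to" a package if its defining module equals the
    # prefix or sits underneath it.
    module = getattr(fn, "__module__", "") or ""
    return module == prefix or module.startswith(prefix + ".")

print(has_module_prefix(numpy.sum, "numpy"))         # True: defined under numpy
print(has_module_prefix(numpy.sum, "numpy.linalg"))  # False: different subpackage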
@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
         return


-def is_to_variable(node):
-    assert isinstance(node, gast.Call)
-    if is_dygraph_api(node):
-        api_name = ast_to_source_code(node.func).strip()
-        return api_name.endswith("to_variable")
-    return False
-
-
 def to_static_ast(node, class_node):
     assert isinstance(node, gast.Call)
     assert isinstance(class_node, gast.Call)
@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
     return node


-def to_assign_node(node):
-    # Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
-    # NOTE:
-    #   1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
-    #   but api `assign` only supports {float32, float64, int32, int64, bool};
-    #   2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
-    assert isinstance(node, gast.Call)
-    assign_api = gast.parse('fluid.layers.assign').body[0].value
-    node.func = assign_api
-
-    if node.args:
-        node.args = [node.args[0]]
-        node.keywords = []
-    else:
-        for idx, kw in enumerate(node.keywords):
-            if kw.arg == 'value':
-                node.keywords[idx].arg = 'input'
-                node.keywords = [node.keywords[idx]]
-                node.args = []
-                break
-    return node
-
-
 def update_args_of_func(node, dygraph_node, method_name):
     assert isinstance(node, gast.Call)
     if method_name not in ["__init__", "forward"]:
......
@@ -19,9 +19,11 @@ import unittest
 import inspect
 import gast

+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dygraph
+from paddle import to_tensor
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph.jit import dygraph_to_static_func
 from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
@@ -45,11 +47,19 @@ def dyfunc_to_variable_3(x):
     return res


+def dyfunc_to_tensor(x):
+    res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
+    res2 = paddle.tensor.to_tensor(data=res1)
+    res3 = to_tensor(data=res2)
+    return res3
+
+
 class TestDygraphBasicApi_ToVariable(unittest.TestCase):
     def setUp(self):
         self.input = np.ones(5).astype("int32")
         self.test_funcs = [
-            dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
+            dyfunc_to_tensor, dyfunc_to_variable, dyfunc_to_variable_2,
+            dyfunc_to_variable_3
         ]
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
......
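As a usage note, the added `dyfunc_to_tensor` is exercised by the same harness as the existing `to_variable` cases: the function is converted with `dygraph_to_static_func` and run inside a static program, at which point the `to_tensor`/`to_variable` calls have been rewritten to `fluid.layers.assign`. Below is a rough, hedged sketch of that flow under the 2.0-beta `fluid` APIs; `simple_to_tensor` is an illustrative stand-in for the test functions, not code from the commit.

# Rough sketch mirroring the TestDygraphBasicApi_ToVariable pattern;
# not the test code itself.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func

def simple_to_tensor(x):
    # Same shape as the dyfunc_to_tensor case above, reduced to one call.
    return paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)

x = np.ones(5).astype("int32")

main_program = fluid.Program()
with fluid.program_guard(main_program):
    # The decorator rewrites the function's AST, so paddle.to_tensor becomes
    # fluid.layers.assign in the static program being built here.
    static_out = dygraph_to_static_func(simple_to_tensor)(x)

exe = fluid.Executor(fluid.CPUPlace())
static_res = exe.run(main_program, fetch_list=[static_out])
print(static_res[0])  # a length-5 int32 array of ones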
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......