未验证 提交 f9d39b49 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Transforme api 'to_tensor' to 'assign'. (#26873)

上级 932bbe95
......@@ -16,9 +16,7 @@ import astor
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static import utils
class BasicApiTransformer(gast.NodeTransformer):
......@@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if isinstance(child_node, gast.Call):
# TODO(liym27):
# Considers that a dygraph api which modifies the input or has a output.
if is_dygraph_api(child_node):
if utils.is_dygraph_api(child_node):
return
else:
self._visit_Call(child_node)
......@@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node)
static_node = utils.to_static_ast(node, class_node)
return static_node
else:
return node
......@@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
if is_to_variable(node_value):
return False
if is_dygraph_api(node_value):
if utils.is_dygraph_api(node_value):
dygraph_api = node_value.func.attr
if not dygraph_class_to_static_api.get(dygraph_api):
if not utils.dygraph_class_to_static_api.get(dygraph_api):
return False
update_args_of_func(node_value, node_value, "__init__")
utils.update_args_of_func(node_value, node_value, "__init__")
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
self.class_node_dict[target_str] = node_value
return True
# TODO: node.value is not dygraph class
return False
def is_to_variable(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
if utils.is_dygraph_api(node):
return api_name.endswith("to_variable")
if utils.is_paddle_api(node):
return api_name.endswith("to_tensor")
return False
def to_assign_node(node):
# Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert isinstance(node, gast.Call)
assign_api = gast.parse('fluid.layers.assign').body[0].value
node.func = assign_api
if node.args:
node.args = [node.args[0]]
node.keywords = []
else:
for idx, kw in enumerate(node.keywords):
if kw.arg == 'value' or kw.arg == 'data':
node.keywords[idx].arg = 'input'
node.keywords = [node.keywords[idx]]
node.args = []
break
return node
......@@ -136,9 +136,12 @@ def is_api_in_module(node, module_prefix):
# import_str = "".join(import_statements)
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import to_variable
import paddle.fluid.dygraph as dygraph
from paddle import to_tensor
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError:
......@@ -146,15 +149,18 @@ def is_api_in_module(node, module_prefix):
def is_dygraph_api(node):
# Note: A api in module dygraph_to_static is not a real dygraph api.
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
return False
# TODO(liym27): A better way to determine whether it is a dygraph api.
# Consider the decorator @dygraph_only
return is_api_in_module(node, "paddle.fluid.dygraph")
def is_paddle_api(node):
return is_api_in_module(node, "paddle.fluid")
return is_api_in_module(node, "paddle")
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
......@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
return
def is_to_variable(node):
assert isinstance(node, gast.Call)
if is_dygraph_api(node):
api_name = ast_to_source_code(node.func).strip()
return api_name.endswith("to_variable")
return False
def to_static_ast(node, class_node):
assert isinstance(node, gast.Call)
assert isinstance(class_node, gast.Call)
......@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
return node
def to_assign_node(node):
# Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert isinstance(node, gast.Call)
assign_api = gast.parse('fluid.layers.assign').body[0].value
node.func = assign_api
if node.args:
node.args = [node.args[0]]
node.keywords = []
else:
for idx, kw in enumerate(node.keywords):
if kw.arg == 'value':
node.keywords[idx].arg = 'input'
node.keywords = [node.keywords[idx]]
node.args = []
break
return node
def update_args_of_func(node, dygraph_node, method_name):
assert isinstance(node, gast.Call)
if method_name not in ["__init__", "forward"]:
......
......@@ -19,9 +19,11 @@ import unittest
import inspect
import gast
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle import to_tensor
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
......@@ -45,11 +47,19 @@ def dyfunc_to_variable_3(x):
return res
def dyfunc_to_tensor(x):
res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
res2 = paddle.tensor.to_tensor(data=res1)
res3 = to_tensor(data=res2)
return res3
class TestDygraphBasicApi_ToVariable(unittest.TestCase):
def setUp(self):
self.input = np.ones(5).astype("int32")
self.test_funcs = [
dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
dyfunc_to_tensor, dyfunc_to_variable, dyfunc_to_variable_2,
dyfunc_to_variable_3
]
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册