diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py
index 82f39ffd080ec803beca4e60695204b707f48210..9334c15f7bcbc0ca3782be1d4f7fc6826a59bdbc 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py
@@ -16,9 +16,7 @@
 import astor
 import gast
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
-from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
-from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
+from paddle.fluid.dygraph.dygraph_to_static import utils
 
 
 class BasicApiTransformer(gast.NodeTransformer):
@@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
             if isinstance(child_node, gast.Call):
                 # TODO(liym27):
                 # Considers that a dygraph api which modifies the input or has a output.
-                if is_dygraph_api(child_node):
+                if utils.is_dygraph_api(child_node):
                     return
                 else:
                     self._visit_Call(child_node)
@@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
 
         if self._is_dygraph_forward(func_name):
             class_node = self._get_class_node(func_name)
-            static_node = to_static_ast(node, class_node)
+            static_node = utils.to_static_ast(node, class_node)
             return static_node
         else:
             return node
@@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
         if is_to_variable(node_value):
             return False
 
-        if is_dygraph_api(node_value):
+        if utils.is_dygraph_api(node_value):
             dygraph_api = node_value.func.attr
-            if not dygraph_class_to_static_api.get(dygraph_api):
+            if not utils.dygraph_class_to_static_api.get(dygraph_api):
                 return False
 
-            update_args_of_func(node_value, node_value, "__init__")
+            utils.update_args_of_func(node_value, node_value, "__init__")
             target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
             self.class_node_dict[target_str] = node_value
             return True
         # TODO: node.value is not dygraph class
         return False
+
+
+def is_to_variable(node):
+    assert isinstance(node, gast.Call)
+    api_name = utils.ast_to_source_code(node.func).strip()
+
+    if utils.is_dygraph_api(node):
+        return api_name.endswith("to_variable")
+
+    if utils.is_paddle_api(node):
+        return api_name.endswith("to_tensor")
+
+    return False
+
+
+def to_assign_node(node):
+    # Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
+    # NOTE:
+    #   1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
+    #   but api `assign` only supports {float32, float64, int32, int64, bool};
+    #   2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
+
+    assert isinstance(node, gast.Call)
+    assign_api = gast.parse('fluid.layers.assign').body[0].value
+    node.func = assign_api
+
+    if node.args:
+        node.args = [node.args[0]]
+        node.keywords = []
+    else:
+        for idx, kw in enumerate(node.keywords):
+            if kw.arg == 'value' or kw.arg == 'data':
+                node.keywords[idx].arg = 'input'
+                node.keywords = [node.keywords[idx]]
+                node.args = []
+                break
+    return node
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
index f344ad2f7d7af00e6037b7552e258bf5c796a3b8..86593dc24aa8bda7906aab2001e8bd285f64288a 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py
@@ -136,9 +136,12 @@ def is_api_in_module(node, module_prefix):
         # import_str = "".join(import_statements)
         import paddle
         import paddle.fluid as fluid
+        import paddle.fluid.dygraph as dygraph
         import paddle.fluid.layers as layers
+        from paddle.fluid.dygraph import to_variable
-        import paddle.fluid.dygraph as dygraph
+        from paddle import to_tensor
+
         return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
                                                                 module_prefix))
     except NameError:
@@ -146,15 +149,18 @@
 
 
 def is_dygraph_api(node):
+    # Note: A api in module dygraph_to_static is not a real dygraph api.
     if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
         return False
 
+    # TODO(liym27): A better way to determine whether it is a dygraph api.
+    # Consider the decorator @dygraph_only
     return is_api_in_module(node, "paddle.fluid.dygraph")
 
 
 def is_paddle_api(node):
-    return is_api_in_module(node, "paddle.fluid")
+    return is_api_in_module(node, "paddle")
 
 
 # Is numpy_api cannot reuse is_api_in_module because of numpy module problem
@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
     return
 
 
-def is_to_variable(node):
-    assert isinstance(node, gast.Call)
-    if is_dygraph_api(node):
-        api_name = ast_to_source_code(node.func).strip()
-        return api_name.endswith("to_variable")
-    return False
-
-
 def to_static_ast(node, class_node):
     assert isinstance(node, gast.Call)
     assert isinstance(class_node, gast.Call)
@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
     return node
 
 
-def to_assign_node(node):
-    # Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
-    # NOTE:
-    #   1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
-    #   but api `assign` only supports {float32, float64, int32, int64, bool};
-    #   2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
-    assert isinstance(node, gast.Call)
-    assign_api = gast.parse('fluid.layers.assign').body[0].value
-    node.func = assign_api
-
-    if node.args:
-        node.args = [node.args[0]]
-        node.keywords = []
-    else:
-        for idx, kw in enumerate(node.keywords):
-            if kw.arg == 'value':
-                node.keywords[idx].arg = 'input'
-                node.keywords = [node.keywords[idx]]
-                node.args = []
-                break
-    return node
-
-
 def update_args_of_func(node, dygraph_node, method_name):
     assert isinstance(node, gast.Call)
     if method_name not in ["__init__", "forward"]:
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py
index 3e6fe168b8eaf39286c518c8b4a2ad6d48b0e6bb..29b4f1b05f9c2911b849b323674b3a704a1da297 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py
@@ -19,9 +19,11 @@
 import unittest
 import inspect
 import gast
 
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dygraph
+from paddle import to_tensor
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph.jit import dygraph_to_static_func
 from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
@@ -45,11 +47,19 @@ def dyfunc_to_variable_3(x):
     return res
 
 
+def dyfunc_to_tensor(x):
+    res1 = paddle.to_tensor(x, dtype=None, place=None, stop_gradient=True)
+    res2 = paddle.tensor.to_tensor(data=res1)
+    res3 = to_tensor(data=res2)
+    return res3
+
+
 class TestDygraphBasicApi_ToVariable(unittest.TestCase):
     def setUp(self):
         self.input = np.ones(5).astype("int32")
         self.test_funcs = [
-            dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
+            dyfunc_to_tensor, dyfunc_to_variable, dyfunc_to_variable_2,
+            dyfunc_to_variable_3
         ]
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
index 214cd95d3bc620b3bcadb88e57c7e54a593eaaf4..e23e071219d17d28b7fca937f459b3df93f42f9a 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
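Note (not part of the patch): the sketch below illustrates the rewriting that the moved `to_assign_node` performs, i.e. turning a `paddle.to_tensor(...)` / `fluid.dygraph.to_variable(...)` call into `fluid.layers.assign(...)` and mapping the `data`/`value` keyword onto `input`. It is a minimal stand-alone approximation that uses only the standard-library `ast` module instead of `gast`/`astor`, so it runs without a Paddle checkout; `rewrite_to_assign` and the sample source string are illustrative names, not part of the Paddle API.

import ast


def rewrite_to_assign(call_node):
    # Point the call at `fluid.layers.assign` instead of the original api.
    call_node.func = ast.parse('fluid.layers.assign', mode='eval').body

    if call_node.args:
        # Positional form: keep only the first argument (the data).
        call_node.args = [call_node.args[0]]
        call_node.keywords = []
    else:
        # Keyword form: `to_variable(value=...)` and `to_tensor(data=...)`
        # both map onto `assign(input=...)`; other keywords are dropped.
        for kw in call_node.keywords:
            if kw.arg in ('value', 'data'):
                kw.arg = 'input'
                call_node.keywords = [kw]
                call_node.args = []
                break
    return call_node


tree = ast.parse("paddle.to_tensor(data=x, stop_gradient=True)", mode='eval')
rewrite_to_assign(tree.body)
print(ast.unparse(tree))  # requires Python 3.9+; prints: fluid.layers.assign(input=x)

As in the patched `to_assign_node`, only the tensor data survives the rewrite; keywords such as `dtype`, `place` and `stop_gradient` are intentionally discarded because `fluid.layers.assign` does not accept them.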