From 573d2faacd262d99aafbe128c1b89331cc58d261 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 24 Mar 2020 18:47:46 +0800 Subject: [PATCH] fix bug in function `is_to_variable`. test=develop (#23147) --- .../fluid/dygraph/dygraph_to_static/utils.py | 11 +++++++++-- .../test_basic_api_transformation.py | 17 ++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 66e45780e6..c0ae00e801 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -45,8 +45,15 @@ def is_api_in_module(node, module_prefix): assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" func_str = astor.to_source(gast.gast_to_ast(node.func)) try: + # TODO(liym27): + # Consider a better to import modules like: + # source_file = inspect.getfile(dyfunc) + # import_statements = ImportVisitor(source_file).transform() + # import_str = "".join(import_statements) import paddle.fluid as fluid import paddle + from paddle.fluid.dygraph import to_variable + import paddle.fluid.dygraph as dygraph return eval("_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)) except NameError: @@ -148,8 +155,8 @@ def _add_keywords_to(node, dygraph_api_name): def is_to_variable(node): assert isinstance(node, gast.Call) if is_dygraph_api(node): - api_name = node.func.attr - return api_name == "to_variable" + api_name = ast_to_source_code(node.func).strip() + return api_name.endswith("to_variable") return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py index e490513468..f4f7a07d55 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py @@ -15,11 +15,13 @@ from __future__ import print_function import numpy as np -import paddle.fluid as fluid import unittest import inspect import gast +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph.jit import dygraph_to_static_graph from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api @@ -33,14 +35,21 @@ def dyfunc_to_variable(x): def dyfunc_to_variable_2(x): - res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32)) + res = dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32)) + return res + + +def dyfunc_to_variable_3(x): + res = to_variable(x, name=None, zero_copy=None) return res class TestDygraphBasicApi_ToVariable(unittest.TestCase): def setUp(self): self.input = np.ones(5).astype("int32") - self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2] + self.test_funcs = [ + dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3 + ] self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() @@ -72,8 +81,6 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): # 1. test Apis that inherit from layers.Layer - - def dyfunc_BilinearTensorProduct(layer1, layer2): bilinearTensorProduct = fluid.dygraph.nn.BilinearTensorProduct( input1_dim=5, -- GitLab