未验证 提交 573d2faa 编写于 作者: L liym27 提交者: GitHub

fix bug in function `is_to_variable`. test=develop (#23147)

上级 23baf865
...@@ -45,8 +45,15 @@ def is_api_in_module(node, module_prefix): ...@@ -45,8 +45,15 @@ def is_api_in_module(node, module_prefix):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_str = astor.to_source(gast.gast_to_ast(node.func)) func_str = astor.to_source(gast.gast_to_ast(node.func))
try: try:
# TODO(liym27):
# Consider a better to import modules like:
# source_file = inspect.getfile(dyfunc)
# import_statements = ImportVisitor(source_file).transform()
# import_str = "".join(import_statements)
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
from paddle.fluid.dygraph import to_variable
import paddle.fluid.dygraph as dygraph
return eval("_is_api_in_module_helper({}, '{}')".format(func_str, return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix)) module_prefix))
except NameError: except NameError:
...@@ -148,8 +155,8 @@ def _add_keywords_to(node, dygraph_api_name): ...@@ -148,8 +155,8 @@ def _add_keywords_to(node, dygraph_api_name):
def is_to_variable(node): def is_to_variable(node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
if is_dygraph_api(node): if is_dygraph_api(node):
api_name = node.func.attr api_name = ast_to_source_code(node.func).strip()
return api_name == "to_variable" return api_name.endswith("to_variable")
return False return False
......
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid
import unittest import unittest
import inspect import inspect
import gast import gast
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_graph from paddle.fluid.dygraph.jit import dygraph_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
...@@ -33,14 +35,21 @@ def dyfunc_to_variable(x): ...@@ -33,14 +35,21 @@ def dyfunc_to_variable(x):
def dyfunc_to_variable_2(x): def dyfunc_to_variable_2(x):
res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32)) res = dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
return res
def dyfunc_to_variable_3(x):
res = to_variable(x, name=None, zero_copy=None)
return res return res
class TestDygraphBasicApi_ToVariable(unittest.TestCase): class TestDygraphBasicApi_ToVariable(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones(5).astype("int32") self.input = np.ones(5).astype("int32")
self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2] self.test_funcs = [
dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
]
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace() ) else fluid.CPUPlace()
...@@ -72,8 +81,6 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): ...@@ -72,8 +81,6 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
# 1. test Apis that inherit from layers.Layer # 1. test Apis that inherit from layers.Layer
def dyfunc_BilinearTensorProduct(layer1, layer2): def dyfunc_BilinearTensorProduct(layer1, layer2):
bilinearTensorProduct = fluid.dygraph.nn.BilinearTensorProduct( bilinearTensorProduct = fluid.dygraph.nn.BilinearTensorProduct(
input1_dim=5, input1_dim=5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册