未验证 提交 573d2faa 编写于 作者: L liym27 提交者: GitHub

fix bug in function `is_to_variable`. test=develop (#23147)

上级 23baf865
......@@ -45,8 +45,15 @@ def is_api_in_module(node, module_prefix):
assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api"
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
# TODO(liym27):
# Consider a better to import modules like:
# source_file = inspect.getfile(dyfunc)
# import_statements = ImportVisitor(source_file).transform()
# import_str = "".join(import_statements)
import paddle.fluid as fluid
import paddle
from paddle.fluid.dygraph import to_variable
import paddle.fluid.dygraph as dygraph
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError:
......@@ -148,8 +155,8 @@ def _add_keywords_to(node, dygraph_api_name):
def is_to_variable(node):
assert isinstance(node, gast.Call)
if is_dygraph_api(node):
api_name = node.func.attr
return api_name == "to_variable"
api_name = ast_to_source_code(node.func).strip()
return api_name.endswith("to_variable")
return False
......
......@@ -15,11 +15,13 @@
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import unittest
import inspect
import gast
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
......@@ -33,14 +35,21 @@ def dyfunc_to_variable(x):
def dyfunc_to_variable_2(x):
res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
res = dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
return res
def dyfunc_to_variable_3(x):
res = to_variable(x, name=None, zero_copy=None)
return res
class TestDygraphBasicApi_ToVariable(unittest.TestCase):
def setUp(self):
self.input = np.ones(5).astype("int32")
self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2]
self.test_funcs = [
dyfunc_to_variable, dyfunc_to_variable_2, dyfunc_to_variable_3
]
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
......@@ -72,8 +81,6 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
# 1. test Apis that inherit from layers.Layer
def dyfunc_BilinearTensorProduct(layer1, layer2):
bilinearTensorProduct = fluid.dygraph.nn.BilinearTensorProduct(
input1_dim=5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册