未验证 提交 b290420f 编写于 作者: L liym27 提交者: GitHub

fix bug in the transformation from to_variable to assign. test=develop (#22885)

上级 ca9c8b41
...@@ -359,15 +359,20 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -359,15 +359,20 @@ class BasicApiTransformer(gast.NodeTransformer):
def _update_feed_dict(self, node): def _update_feed_dict(self, node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
var_name = None value_node = None
for kw in node.keywords: for kw in node.keywords:
if kw.arg == 'value': if kw.arg == 'value':
var_name = kw.value.id # eg: 'a' for "value=a " value_node = kw.value # eg: `a` for "value=a "
if not var_name: if not value_node:
var_name = node.args[0].id value_node = node.args[0]
feed_var_name = unique_name.generate(var_name) # eg: "a_0" if not isinstance(value_node, gast.Name):
self.feed_name_to_arg_id[feed_var_name] = var_name # eg: "a_0" : "a" return
else:
var_name = value_node.id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[
feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self): def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id return self.feed_name_to_arg_id
......
...@@ -150,11 +150,27 @@ def to_static_ast(node, class_node): ...@@ -150,11 +150,27 @@ def to_static_ast(node, class_node):
return node return node
def to_assign_node(ori_node): def to_assign_node(node):
assert isinstance(ori_node, gast.Call) # Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert isinstance(node, gast.Call)
assign_api = gast.parse('fluid.layers.assign').body[0].value assign_api = gast.parse('fluid.layers.assign').body[0].value
ori_node.func = assign_api node.func = assign_api
return ori_node
if node.args:
node.args = [node.args[0]]
node.keywords = []
else:
for idx, kw in enumerate(node.keywords):
if kw.arg == 'value':
node.keywords[idx].arg = 'input'
node.keywords = [node.keywords[idx]]
node.args = []
break
return node
def update_args_of_func(node, dygraph_node, method_name): def update_args_of_func(node, dygraph_node, method_name):
......
...@@ -28,19 +28,25 @@ np.random.seed(SEED) ...@@ -28,19 +28,25 @@ np.random.seed(SEED)
def dyfunc_to_variable(x): def dyfunc_to_variable(x):
res = fluid.dygraph.to_variable(x) res = fluid.dygraph.to_variable(x, name=None, zero_copy=None)
return res
def dyfunc_to_variable_2(x):
res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32))
return res return res
class TestDygraphBasicApi_ToVariable(unittest.TestCase): class TestDygraphBasicApi_ToVariable(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones(5).astype("int32") self.input = np.ones(5).astype("int32")
self.dygraph_func = dyfunc_to_variable self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2]
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def get_dygraph_output(self): def get_dygraph_output(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
res = self.dygraph_func(self.input).numpy() res = self.dygraph_func(self.input).numpy()
return res return res
def get_static_output(self): def get_static_output(self):
...@@ -49,18 +55,20 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): ...@@ -49,18 +55,20 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase):
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input) static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out) static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0] return static_res[0]
def test_transformed_static_result(self): def test_transformed_static_result(self):
dygraph_res = self.get_dygraph_output() for func in self.test_funcs:
static_res = self.get_static_output() self.dygraph_func = func
self.assertTrue( dygraph_res = self.get_dygraph_output()
np.allclose(dygraph_res, static_res), static_res = self.get_static_output()
msg='dygraph is {}\n static_res is {}'.format(dygraph_res, self.assertTrue(
static_res)) np.allclose(dygraph_res, static_res),
msg='dygraph is {}\n static_res is {}'.format(dygraph_res,
static_res))
# 1. test Apis that inherit from layers.Layer # 1. test Apis that inherit from layers.Layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册