diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 6305dcee70cca80915ad0ee5c113d82aa40cae04..b175dde68ead6affd32a2d35ed38504c237ae7ca 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -359,15 +359,20 @@ class BasicApiTransformer(gast.NodeTransformer): def _update_feed_dict(self, node): assert isinstance(node, gast.Call) - var_name = None + value_node = None for kw in node.keywords: if kw.arg == 'value': - var_name = kw.value.id # eg: 'a' for "value=a " - if not var_name: - var_name = node.args[0].id + value_node = kw.value # eg: `a` for "value=a " + if not value_node: + value_node = node.args[0] - feed_var_name = unique_name.generate(var_name) # eg: "a_0" - self.feed_name_to_arg_id[feed_var_name] = var_name # eg: "a_0" : "a" + if not isinstance(value_node, gast.Name): + return + else: + var_name = value_node.id + feed_var_name = unique_name.generate(var_name) # eg: "a_0" + self.feed_name_to_arg_id[ + feed_var_name] = var_name # eg: "a_0" : "a" def get_feed_name_to_arg_id(self): return self.feed_name_to_arg_id diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index cf54043dd37a8bff017fdc4e559a4892601f7994..0a72881c2c4ed06fe2c232a422c2139250669322 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -150,11 +150,27 @@ def to_static_ast(node, class_node): return node -def to_assign_node(ori_node): - assert isinstance(ori_node, gast.Call) +def to_assign_node(node): + # Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`. + # NOTE: + # 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16}, + # but api `assign` only supports {float32, float64, int32, int64, bool}; + # 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024. + assert isinstance(node, gast.Call) assign_api = gast.parse('fluid.layers.assign').body[0].value - ori_node.func = assign_api - return ori_node + node.func = assign_api + + if node.args: + node.args = [node.args[0]] + node.keywords = [] + else: + for idx, kw in enumerate(node.keywords): + if kw.arg == 'value': + node.keywords[idx].arg = 'input' + node.keywords = [node.keywords[idx]] + node.args = [] + break + return node def update_args_of_func(node, dygraph_node, method_name): diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py index e60cca70b479443e20bfe35b1bcbeaf7cc3ed9ab..38a10ecb5dc0d0cbb3ed9b627fe15ae02be6843d 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py @@ -28,19 +28,25 @@ np.random.seed(SEED) def dyfunc_to_variable(x): - res = fluid.dygraph.to_variable(x) + res = fluid.dygraph.to_variable(x, name=None, zero_copy=None) + return res + + +def dyfunc_to_variable_2(x): + res = fluid.dygraph.to_variable(value=np.zeros(shape=(1), dtype=np.int32)) return res class TestDygraphBasicApi_ToVariable(unittest.TestCase): def setUp(self): self.input = np.ones(5).astype("int32") - self.dygraph_func = dyfunc_to_variable + self.test_funcs = [dyfunc_to_variable, dyfunc_to_variable_2] + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() def get_dygraph_output(self): with fluid.dygraph.guard(): res = self.dygraph_func(self.input).numpy() - return res def get_static_output(self): @@ -49,18 +55,20 @@ class TestDygraphBasicApi_ToVariable(unittest.TestCase): with fluid.program_guard(main_program): static_out = dygraph_to_static_graph(self.dygraph_func)(self.input) - exe = fluid.Executor(fluid.CPUPlace()) + exe = fluid.Executor(self.place) static_res = exe.run(main_program, fetch_list=static_out) return static_res[0] def test_transformed_static_result(self): - dygraph_res = self.get_dygraph_output() - static_res = self.get_static_output() - self.assertTrue( - np.allclose(dygraph_res, static_res), - msg='dygraph is {}\n static_res is {}'.format(dygraph_res, - static_res)) + for func in self.test_funcs: + self.dygraph_func = func + dygraph_res = self.get_dygraph_output() + static_res = self.get_static_output() + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph is {}\n static_res is {}'.format(dygraph_res, + static_res)) # 1. test Apis that inherit from layers.Layer