From 82630f383433b2741324fd40820e163268fadbd5 Mon Sep 17 00:00:00 2001
From: Huihuang Zheng
Date: Fri, 4 Jun 2021 15:19:19 +0800
Subject: [PATCH] [Dy2stat] Add Support for paddle.grad (#33110)

This PR makes the following changes to support double grad:
1. Translate `paddle.grad` to `paddle.static.gradients` to support double grad for dy2stat.
2. Fix an IfElseTransformer bug: the value of a "Store before Load" variable may not be changed if its "Store" statement is inside an IfElse conditional statement.
3. Add `DOut` to support double grad variables in `run_program_op`.
4. Add support for renaming double grad variables in `jit.save/load`.
---
 paddle/fluid/operators/run_program_op.cc      |   8 ++
 paddle/fluid/operators/run_program_op.h       |  14 ++-
 .../dygraph_to_static/ast_transformer.py      |   2 +
 .../dygraph_to_static/grad_transformer.py     |  87 ++++++++++++++
 .../dygraph_to_static/ifelse_transformer.py   |  10 +-
 .../dygraph_to_static/partial_program.py      |  31 ++++-
 python/paddle/fluid/dygraph/io.py             |  99 ++++++++++++++--
 .../unittests/dygraph_to_static/test_grad.py  | 111 ++++++++++++++++++
 .../tests/unittests/test_run_program_op.py    |   5 +
 9 files changed, 341 insertions(+), 26 deletions(-)
 create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_grad.py

diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc
index 2d59971644..69b2c5b738 100644
--- a/paddle/fluid/operators/run_program_op.cc
+++ b/paddle/fluid/operators/run_program_op.cc
@@ -83,6 +83,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
               "contains at most one scope."
               "NOTE: Do not use Scope directly because Scope output is not "
               "currently supported.");
+    AddOutput("DOut",
+              "(vector<LoDTensor>)"
+              "The output tensors for GRAD Tensors in RunProgram forward "
+              "operator, the forward operator contains GRAD Tensors when it "
+              "computes double grad.")
+        .AsDuplicable()
+        .AsDispensable();
     AddAttr<BlockDesc*>("global_block",
                         "(BlockDesc *)"
                         "The global block of executed program desc.");
@@ -154,6 +161,7 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetInput("Params", this->Input("Params"));
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetInput("OutScope", this->Output("OutScope"));
+    grad_op->SetInput("DOut", this->Output("DOut"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetOutput(framework::GradVarName("Params"),
                        this->InputGrad("Params"));
diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h
index f78f5c5b94..c7aeb0e145 100644
--- a/paddle/fluid/operators/run_program_op.h
+++ b/paddle/fluid/operators/run_program_op.h
@@ -131,6 +131,9 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
+    if (var_names[i] == "Fake_var") {
+      continue;
+    }
     auto *var = scope->Var(var_names[i]);
     CheckInputVarStatus(*vars[i], var_names[i]);
     VariableShare(*vars[i], var);
@@ -141,9 +144,9 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
-    if (var_names[i] == framework::kEmptyVarName) {
-      VLOG(2) << "find variable name is " << framework::kEmptyVarName
-              << ", skip it!";
+    if (var_names[i] == framework::kEmptyVarName ||
+        var_names[i] == "Fake_var") {
+      VLOG(2) << "find variable name is " <<
var_names[i] << ", skip it!"; continue; } // NOTE: Here skip not found var is dangerous, if a bug is caused here, @@ -170,9 +173,11 @@ class RunProgramOpKernel : public framework::OpKernel { auto &input_vars = ctx.MultiInputVar("X"); auto ¶m_vars = ctx.MultiInputVar("Params"); auto output_vars = ctx.MultiOutputVar("Out"); + auto dout_vars = ctx.MultiOutputVar("DOut"); auto input_var_names = ctx.InputNames("X"); auto output_var_names = ctx.OutputNames("Out"); + auto dout_var_names = ctx.OutputNames("DOut"); // current program may not hold parameters std::vector param_names; @@ -195,7 +200,7 @@ class RunProgramOpKernel : public framework::OpKernel { // Step 2. prepare executor and init persistable variables framework::Executor exe(ctx.GetPlace()); auto exe_ctx = framework::GetExecutorInfoFromCache( - exe, ctx, {output_var_names}, /*is_grad=*/false); + exe, ctx, {output_var_names, dout_var_names}, /*is_grad=*/false); // NOTE(Aurelius84): While training some models, forward can be called many // times and then apply backpropagation all at once, such as Reinforcement @@ -219,6 +224,7 @@ class RunProgramOpKernel : public framework::OpKernel { // Step 4. Get Output details::ShareVarsFromScope(output_vars, output_var_names, &scope); + details::ShareVarsFromScope(dout_vars, dout_var_names, &scope); // Debug info: scope info when run end VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front()); diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index fa168a62de..29eee429ef 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -25,6 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer +from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer @@ -86,6 +87,7 @@ class DygraphToStaticAst(gast.NodeTransformer): PrintTransformer, # print statement CallTransformer, # transform call recursively CastTransformer, # type casting statement + GradTransformer, # transform paddle.grad to paddle.gradients ] for index, transformer in enumerate(transformers): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py new file mode 100644 index 0000000000..f7a59063ae --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast +import warnings + +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static import utils + + +class GradTransformer(gast.NodeTransformer): + """ + A class transforms dygraph paddle.grad to static graph paddle.gradients. The + transformation is applied to support double grad mode. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of GradTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + + def visit_Call(self, node): + self.generic_visit(node) + if not is_grad_api_node(node): + return node + + dygraph_grad_parameters = [ + "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph", + "only_inputs", "allow_unused", "no_grad_vars" + ] + to_static_grad_param = { + "outputs": "targets", + "inputs": "inputs", + "grad_outputs": "target_gradients", + "no_grad_vars": "no_grad_set" + } + static_keywords = [] + + for kw in node.keywords: + if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param: + warnings.warn("paddle.grad has unsupported parameter in jit: " + + kw.arg + ", jit will discard it") + continue + dygraph_grad_parameters.remove(kw.arg) + kw.arg = to_static_grad_param[kw.arg] + static_keywords.append(kw) + + for i in range(len(node.args)): + arg_name = dygraph_grad_parameters[i] + if arg_name not in to_static_grad_param: + warnings.warn("paddle.grad has unsupported parameter in jit: " + + kw.arg + ", jit will discard it") + continue + kw = gast.keyword( + arg=to_static_grad_param[arg_name], value=node.args[i]) + static_keywords.append(kw) + + node.func = gast.parse('paddle.static.gradients').body[0].value + node.keywords = static_keywords + node.args = [] + return node + + +def is_grad_api_node(node): + assert isinstance(node, gast.Call) + api_name = utils.ast_to_source_code(node.func).strip() + if utils.is_paddle_api(node): + return api_name.endswith("grad") + return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index de788487fe..5bc1c3d96d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -402,7 +402,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, var for var in _vars_with_store(child_dict) if var in parent_dict ]) - def _vars_loaded_before_store(ids_dict): + def _vars_loaded(ids_dict): """ gast.Param is also a kind of `load` semantic. 
""" @@ -411,8 +411,6 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, for ctx in ctxs: if isinstance(ctx, (gast.Load, gast.Param)): new_dict[k].append(ctx) - elif isinstance(ctx, gast.Store): - break return new_dict # modified vars @@ -439,8 +437,12 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars # 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. + # TODO(zhhsplendid): the _vars_loaded can be optimized as _vars_loaded_before_store. Because if a variable is stored before load, + # the value would change by the store statement, we don't have to return to change the value. However, analysis is + # complex because if the IfElse is nested and outer IfElse store statement may not run at all. We will put this optimization + # as the future TODO used_vars_after_ifelse = set( - [var for var in _vars_loaded_before_store(after_ifelse_vars_dict)]) + [var for var in _vars_loaded(after_ifelse_vars_dict)]) new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse # 4. generate return_ids of if/else node. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index feb8b0f9c9..6eea883226 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -135,6 +135,7 @@ class PartialProgramLayer(layers.Layer): self._origin_main_program = self._verify_program(main_program) self._inner_scope = core.Scope() # Set default mode to train + self._double_grads = self._get_double_grads(self._origin_main_program) self.training = True @LazyInitialized @@ -192,24 +193,44 @@ class PartialProgramLayer(layers.Layer): """ required_params = [] for param in self._params: + found_param = False for block in program.blocks: - if param.name in block.vars: - required_params.append(param) + for op in block.ops: + if param.name in op.input_arg_names or param.name in op.output_arg_names: + required_params.append(param) + found_param = True + break + if found_param: break self._params = required_params + def _get_double_grads(self, program): + double_grads = [] + for block in program.blocks: + for name in block.vars: + if "@GRAD" in name: + var_desc = block.vars[name].desc + var_base = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), + var_desc.type(), False) + double_grads.append(var_base) + return double_grads + def forward(self, inputs): in_vars, out_vars, tmp_scope_vec = self._prepare(inputs) - framework._dygraph_tracer().trace_op( type='run_program', inputs={ 'X': valid_vars(in_vars), 'Params': valid_vars(self._params) }, - outputs={'Out': valid_vars(out_vars), - 'OutScope': tmp_scope_vec}, + outputs={ + 'Out': valid_vars(out_vars), + 'OutScope': tmp_scope_vec, + 'DOut': valid_vars(self._double_grads) + }, attrs={ 'global_block': self.program.desc.block(0), 'start_op_index': 0, diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 33eb16f1b2..d5ad3a88e8 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -166,29 +166,46 @@ def _get_loaded_var_new_old(program_desc, all_new_old_dict_all): def _rename_var_program_desc(program_desc, include=None, exclude=None): """ - Change the name of the loaded variables.Use 'unique_name.generate' to 
avoid duplication - e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. - If 'include' is not `None`,variables that are not in include are not renamed. - If 'exclude' is not `None`,variables that are in exclude will are not renamed. + Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication. + It is used when loading multiple program during inference. + + e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. For double grad, x@GRAD ==> x_0@GRAD + If 'include' is not `None`,variables in include and the corresponding + double grad variables (if exist) are renamed. + If 'exclude' is not `None`,variables that are in exclude and the + corresponding double grad variables (if exist) are not renamed. Args: program_desc(ProgramDesc):the variables in it will be modified. include(List):list of names of variables. exclude(List):list of names of variables. + + Returns: + tuple of (dict_rename_var_new_old, dict_rename_var_old_new) + dict_rename_var_new_old is a dict mapping from new name to old name + dict_rename_var_old_new is a dict mapping from old name to new name """ dict_rename_var_old_new = dict() dict_rename_var_new_old = dict() old_names = [] + # Store all old names for b_idx in six.moves.range(program_desc.num_blocks()): cur_block = program_desc.block(b_idx) for var in cur_block.all_vars(): old_names.append(var.name()) + + # Create dict_rename_var_new_old and dict_rename_var_old_new for non double + # grad variables + has_double_grad = False for b_idx in six.moves.range(program_desc.num_blocks()): cur_block = program_desc.block(b_idx) for var_idx, var in enumerate(cur_block.all_vars()): name_old = var.name() + is_double_grad_var = "@GRAD" in name_old + has_double_grad = has_double_grad or is_double_grad_var should_rename = (include is None or name_old in include) and ( - exclude is None or name_old not in exclude) + exclude is None or + name_old not in exclude) and not is_double_grad_var if should_rename: temp_name = name_old.split('_') if len(temp_name) > 1 and temp_name[-1].isnumeric(): @@ -206,9 +223,29 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None): if name_old != name_new: cur_block._rename_var( cpt.to_bytes(name_old), cpt.to_bytes(name_new)) - dict_rename_var_old_new[name_old] = name_new - dict_rename_var_new_old[name_new] = name_old - + if not is_double_grad_var: + dict_rename_var_old_new[name_old] = name_new + dict_rename_var_new_old[name_new] = name_old + + # Handle double grad names + if has_double_grad: + double_grad_rename_dict = {} + for name_old in dict_rename_var_old_new: + for b_idx in six.moves.range(program_desc.num_blocks()): + cur_block = program_desc.block(b_idx) + for var_idx, var in enumerate(cur_block.all_vars()): + var_name = var.name() + if "@GRAD" in var_name and name_old in var_name: + new_var_name = var_name.replace( + name_old, dict_rename_var_old_new[name_old]) + double_grad_rename_dict[var_name] = new_var_name + for var_name in double_grad_rename_dict: + dict_rename_var_old_new[var_name] = double_grad_rename_dict[ + var_name] + dict_rename_var_new_old[double_grad_rename_dict[ + var_name]] = var_name + + # Rename on program desc for b_idx in six.moves.range(program_desc.num_blocks()): cur_block = program_desc.block(b_idx) for op_idx in six.moves.range(cur_block.op_size()): @@ -220,6 +257,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None): op._rename_input( input_arg_name, dict_rename_var_old_new[input_arg_name]) + if cur_block.has_var(cpt.to_bytes(input_arg_name)): + 
cur_block._rename_var( + cpt.to_bytes(input_arg_name), + cpt.to_bytes(dict_rename_var_old_new[ + input_arg_name])) for output_arg_name in op.output_arg_names(): if output_arg_name in dict_rename_var_old_new: if output_arg_name != dict_rename_var_old_new[ @@ -227,6 +269,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None): op._rename_output( output_arg_name, dict_rename_var_old_new[output_arg_name]) + if cur_block.has_var(cpt.to_bytes(output_arg_name)): + cur_block._rename_var( + cpt.to_bytes(output_arg_name), + cpt.to_bytes(dict_rename_var_old_new[ + output_arg_name])) program_desc.flush() return dict_rename_var_new_old, dict_rename_var_old_new @@ -267,9 +314,10 @@ class _ProgramHolder(object): def __init__(self, program_desc): super(_ProgramHolder, self).__init__() - # input, output, persistable var info + # input, output, persistable, double_grads var info self._input_descs = [] self._output_descs = [] + self._double_grad_descs = [] self._persistable_names = [] # execution scope @@ -277,7 +325,6 @@ class _ProgramHolder(object): # append suffix var name dict self._suffix_varname_dict = None - # forward program self._infer_program_desc = self._preprocess(program_desc) # forward + backward program @@ -304,6 +351,10 @@ class _ProgramHolder(object): def persistable_names(self): return self._persistable_names + @property + def double_grad_descs(self): + return self._double_grad_descs + @property def scope(self): return self._inner_scope @@ -347,6 +398,12 @@ class _ProgramHolder(object): for op_idx in reversed(ops_to_remove): root_block._remove_op(op_idx, op_idx + 1) + for i in range(program_desc.num_blocks()): + block_desc = program_desc.block(i) + for var_desc in block_desc.all_vars(): + if "@GRAD" in var_desc.name(): + self._double_grad_descs.append(var_desc) + # 2. Input processing, reverse feed vars self._input_descs.reverse() @@ -412,7 +469,6 @@ class _ProgramHolder(object): # rewrite a series of methods for append_backward for program_desc. # Therefore, in order to reuse the method of backward.py, build the program here. program = _build_program_by_desc(program_desc_copy) - # 3. Add the outputs which is only used for training and not saved in # inference program. for block_idx in six.moves.range(program.num_blocks): @@ -738,6 +794,20 @@ def _run_dygraph(instance, input, program_holder): core.VarDesc.VarType.STEP_SCOPES, True) tmp_scope_vec.value().set_scope(program_holder.scope) + double_grad_vars = [] + for var_desc in program_holder.double_grad_descs: + var = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + double_grad_vars.append(var) + if len(double_grad_vars) == 0: + double_grad_vars = [ + core.VarBase( + value=[1], + name='Fake_var', + place=framework._current_expected_place()) + ] + # 2. 
run program by op trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program end_op_index = program_holder.infer_program.block(0).op_size() @@ -745,8 +815,11 @@ def _run_dygraph(instance, input, program_holder): type='run_program', inputs={'X': input_vars, 'Params': persistable_vars}, - outputs={'Out': output_vars, - 'OutScope': tmp_scope_vec}, + outputs={ + 'Out': output_vars, + 'OutScope': tmp_scope_vec, + 'DOut': double_grad_vars + }, attrs={ 'global_block': trace_program.block(0), 'start_op_index': 0, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_grad.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_grad.py new file mode 100644 index 0000000000..ab87beb9e1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_grad.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle +import unittest + + +class GradLayer(paddle.nn.Layer): + def __init__(self): + super(GradLayer, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + x.stop_gradient = False + y = x * x + dx = paddle.grad(outputs=[y], inputs=[x])[0] + return dx + + +class GradLinearLayer(paddle.nn.Layer): + def __init__(self): + super(GradLinearLayer, self).__init__() + self.linear = paddle.nn.Linear(5, 5, bias_attr=False) + + @paddle.jit.to_static + def forward(self, x): + x.stop_gradient = False + tmp = x + x + for i in range(10): + tmp = self.linear(tmp) + out = tmp + dx = paddle.grad( + [out], [x], None, create_graph=True, allow_unused=False)[0] + return dx + + +class TestGrad(unittest.TestCase): + def setUp(self): + self.func = GradLayer() + self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') + self.x.stop_gradient = False + + def _run(self, func, to_static): + prog_trans = paddle.jit.ProgramTranslator() + prog_trans.enable(to_static) + ret = func(self.x).numpy() + prog_trans.enable(True) + return ret + + def test_forward(self): + dygraph_res = self._run(self.func, to_static=False) + static_res = self._run(self.func, to_static=True) + self.assertTrue(np.allclose(static_res, dygraph_res)) + + +class TestGradLinear(TestGrad): + def setUp(self): + self.func = GradLinearLayer() + self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') + self.x.stop_gradient = False + + def test_save_infer_program(self): + path = "double_grad_infer_model" + input_spec = [ + paddle.static.InputSpec( + shape=[10, 2, 5], dtype='float32') + ] + paddle.jit.save(self.func, path, input_spec=input_spec) + load_func = paddle.jit.load(path) + + origin_res = self.func(self.x).numpy() + load_res = load_func(self.x).numpy() + self.assertTrue(np.allclose(origin_res, load_res)) + + def test_save_train_program(self): + grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + grad_clip=grad_clip, + parameters=self.func.parameters()) + for i in 
range(10): + out = self.func(self.x) + avg_loss = paddle.mean(paddle.abs(out - 1)) + avg_loss.backward() + optimizer.minimize(avg_loss) + self.func.clear_gradients() + + path = "double_grad_train_model" + paddle.jit.save(self.func, path) + load_func = paddle.jit.load(path) + + origin_res = self.func(self.x).numpy() + load_res = load_func(self.x).numpy() + self.assertTrue(np.allclose(origin_res, load_res)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py index f6332859f9..81490642fa 100644 --- a/python/paddle/fluid/tests/unittests/test_run_program_op.py +++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py @@ -19,10 +19,13 @@ import unittest import numpy as np import six +import paddle import paddle.fluid as fluid from paddle import compat as cpt from paddle.fluid import core, framework, executor +paddle.enable_static() + @contextlib.contextmanager def program_scope_guard(): @@ -164,6 +167,8 @@ class RunProgramOpTest(unittest.TestCase): persistable=True) inner_scope = core.Scope() outputs['OutScope'].value().set_scope(inner_scope) + + outputs['DOut'] = [create_var_base(False, "Fake_var")] return outputs def calc_dygraph_output(self, place): -- GitLab
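
A minimal, self-contained sketch (not part of the patch) of the behavior this change enables: a `paddle.grad` call inside a `@paddle.jit.to_static` forward returns the same first-order gradient whether the layer runs in dygraph or as the translated static program with the new `DOut` output. It is modeled on `test_grad.py` above; the layer name `CubeGradLayer` and the cubic function are illustrative, not from the patch.

```python
import numpy as np
import paddle


class CubeGradLayer(paddle.nn.Layer):
    """Return dy/dx for y = x^3; paddle.grad runs inside the traced program."""

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        y = x * x * x
        # Under @to_static this call is rewritten to paddle.static.gradients by
        # GradTransformer; create_graph=True keeps the result differentiable.
        dx = paddle.grad([y], [x], create_graph=True)[0]
        return dx


def run(layer, x, to_static):
    # Same toggle pattern as the unit test: disable/enable dy2stat translation.
    prog_trans = paddle.jit.ProgramTranslator()
    prog_trans.enable(to_static)
    out = layer(x).numpy()
    prog_trans.enable(True)
    return out


layer = CubeGradLayer()
x = paddle.full(shape=[4], fill_value=2.0, dtype='float32')
x.stop_gradient = False

dygraph_dx = run(layer, x, to_static=False)  # eager paddle.grad
static_dx = run(layer, x, to_static=True)    # run_program op with DOut

# dy/dx = 3 * x^2 = 12 at x = 2, identical in both modes.
assert np.allclose(dygraph_dx, 12.0)
assert np.allclose(dygraph_dx, static_dx)
```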
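For reference, a standalone static-graph sketch of what the `GradTransformer` rewrite effectively calls. The keyword mapping (`outputs` -> `targets`, `inputs` -> `inputs`, `grad_outputs` -> `target_gradients`, `no_grad_vars` -> `no_grad_set`) is taken from `to_static_grad_param` in `grad_transformer.py`; unsupported keywords such as `create_graph` or `retain_graph` are discarded with a warning. The program structure and variable names below are illustrative, not from the patch.

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[10, 2, 5], dtype='float32')
    x.stop_gradient = False
    y = x * x
    # Dygraph code written as
    #     dx = paddle.grad(outputs=[y], inputs=[x])[0]
    # is rewritten by GradTransformer into the static API below.
    dx = paddle.static.gradients(targets=[y], inputs=[x])[0]

exe = paddle.static.Executor(paddle.CPUPlace())
x_np = np.ones([10, 2, 5], dtype='float32')
dx_np = exe.run(main_prog, feed={'x': x_np}, fetch_list=[dx])[0]

assert np.allclose(dx_np, 2.0)  # d(x * x)/dx = 2x = 2 at x = 1

paddle.disable_static()
```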
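Finally, a hedged sketch of the `jit.save`/`jit.load` round trip for a layer whose program contains double grad variables. Loading goes through `_rename_var_program_desc` in `io.py`, which with this patch also renames the corresponding double grad variables (e.g. renaming `x` to `x_0` implies `x@GRAD` becomes `x_0@GRAD`). It is modeled on `TestGradLinear.test_save_infer_program`; the layer class and file prefix are illustrative.

```python
import numpy as np
import paddle


class SaveLoadGradLayer(paddle.nn.Layer):
    def __init__(self):
        super(SaveLoadGradLayer, self).__init__()
        self.linear = paddle.nn.Linear(5, 5, bias_attr=False)

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        tmp = self.linear(self.linear(x + x))
        # First-order grad inside the program; its GRAD vars flow through DOut.
        dx = paddle.grad([tmp], [x], create_graph=True)[0]
        return dx


layer = SaveLoadGradLayer()
x = paddle.ones(shape=[10, 2, 5], dtype='float32')
x.stop_gradient = False

path = "double_grad_example_model"  # illustrative file prefix
input_spec = [paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32')]
paddle.jit.save(layer, path, input_spec=input_spec)
loaded_layer = paddle.jit.load(path)

# The reloaded program should reproduce the original gradient output.
assert np.allclose(layer(x).numpy(), loaded_layer(x).numpy())
```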