From d636d72e132b47745426f712a954a5f2a79894f8 Mon Sep 17 00:00:00 2001 From: wangruting Date: Wed, 18 Jan 2023 06:28:30 +0000 Subject: [PATCH] modify name --- .../generator/eager_gen.py | 16 ++++++--- .../manual/backward/composite_backward_api.h | 33 ++++++++++++------- .../prim/api/manual/utils/eager_utils.cc | 12 ++++++- .../prim/api/manual/utils/static_utils.cc | 9 ++++- paddle/fluid/prim/api/manual/utils/utils.h | 5 +++ 5 files changed, 57 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 414270a544..000c127f51 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): False if self.composite_func_info == {} else True ) + if is_composite_grad_api: + next_grad_node_creation_str = f""" + if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ + {next_grad_node_creation_str} + }} + """ + if next_node_generator is not None: has_higher_order_node = True return ( @@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): next_node_generator.backward_forward_inputs_map, ) # TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered - # TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future - # if is_composite_grad_api: - # next_grad_node_creation_str = '' elif not is_invoke_forward_api and not is_composite_grad_api: next_grad_node_creation_str = f""" if(trace_backward) {{ PADDLE_THROW(phi::errors::Unavailable( @@ -2275,7 +2279,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # Prepare for Node Creation if Necessary outputs_autograd_meta_str = "" compute_require_next_grad_str = "" - if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api: + if ( + len(next_grad_node_creation_str) > 0 + or is_invoke_forward_api + or inplace_for_grad_outs_str != '' + ): compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n" # 3. Get Output AutoGradMeta diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h index 19898e0c56..fc276842b8 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h @@ -31,7 +31,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { auto tmp = pow(out, 2.0); tmp = scale(tmp, -1.0, 1.0, true); auto grad_x_tmp = multiply(grad_out, tmp); - grad_x->set_impl(grad_x_tmp.impl()); + set_output(grad_x_tmp, grad_x); } template void subtract_grad(const Tensor& x, @@ -48,7 +48,8 @@ void subtract_grad(const Tensor& x, auto dy_reduce_res = sum(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + set_output(dy_tmp, dy); + // dy->set_impl(dy_tmp.impl()); } else { by_pass(scale_out_grad, dy); } @@ -60,7 +61,8 @@ void subtract_grad(const Tensor& x, auto dx_reduce_res = sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - dx->set_impl(dx_tmp.impl()); + set_output(dx_tmp, dx); + // dx->set_impl(dx_tmp.impl()); } else { by_pass(out_grad, dx); } @@ -81,7 +83,8 @@ void add_grad(const Tensor& x, auto dy_reduce_res = sum(out_grad, phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + set_output(dy_tmp, dy); + // dy->set_impl(dy_tmp.impl()); } else { by_pass(out_grad, dy); } @@ -93,7 +96,8 @@ void add_grad(const Tensor& x, auto dx_reduce_res = sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - dx->set_impl(dx_tmp.impl()); + set_output(dx_tmp, dx); + // dx->set_impl(dx_tmp.impl()); } else { by_pass(out_grad, dx); } @@ -134,8 +138,8 @@ void sum_grad(const Tensor& x, } else { x_grad_tmp = expand(out_grad, x_dim); } - - x_grad->set_impl(x_grad_tmp.impl()); + set_output(x_grad_tmp, x_grad); + // x_grad->set_impl(x_grad_tmp.impl()); } template @@ -158,9 +162,11 @@ void divide_grad(const Tensor& x, auto dy_reduce_res = sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + set_output(dy_tmp, dy); + // dy->set_impl(dy_tmp.impl()); } else { - dy->set_impl(dy_res.impl()); + set_output(dy_res, dy); + // dy->set_impl(dy_res.impl()); } } // indicate we will compute dy if (dx) { @@ -174,9 +180,11 @@ void divide_grad(const Tensor& x, auto dx_reduce_res = sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - dx->set_impl(dx_tmp.impl()); + set_output(dx_tmp, dx); + // dx->set_impl(dx_tmp.impl()); } else { - dx->set_impl(dx_res.impl()); + set_output(dx_res, dx); + // dx->set_impl(dx_res.impl()); } } // indicate we will compute dx } @@ -187,7 +195,8 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { auto div_x = full(phi::vectorize(out.dims()), 0.5); auto tmp = divide(div_x, out); auto x_grad_tmp = multiply(out_grad, tmp); - x_grad->set_impl(x_grad_tmp.impl()); + set_output(x_grad_tmp, x_grad); + // x_grad->set_impl(x_grad_tmp.impl()); } } } // namespace prim diff --git a/paddle/fluid/prim/api/manual/utils/eager_utils.cc b/paddle/fluid/prim/api/manual/utils/eager_utils.cc index 96d0b4ea1f..dbf9615058 100644 --- a/paddle/fluid/prim/api/manual/utils/eager_utils.cc +++ b/paddle/fluid/prim/api/manual/utils/eager_utils.cc @@ -38,9 +38,19 @@ Tensor empty_like(const paddle::experimental::Tensor& x, } return empty_like_ad_func(x, dtype, place); } + +template <> +void set_output(const paddle::experimental::Tensor& x_tmp, + paddle::experimental::Tensor* x) { + x->set_impl(x_tmp.impl()); + x->set_autograd_meta(x_tmp.mutable_autograd_meta()); +} + template <> void by_pass(const paddle::experimental::Tensor& x, Tensor* out) { - out->set_impl(x.impl()); + set_output(x, out); + // out->set_impl(x.impl()); } + } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/utils/static_utils.cc b/paddle/fluid/prim/api/manual/utils/static_utils.cc index c90cfdb34f..4def77831b 100644 --- a/paddle/fluid/prim/api/manual/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual/utils/static_utils.cc @@ -47,6 +47,12 @@ Tensor empty_like(const Tensor& x, paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place()); } +template <> +void set_output(const paddle::experimental::Tensor& x_tmp, + paddle::experimental::Tensor* x) { + x->set_impl(x_tmp.impl()); +} + template <> void by_pass(const paddle::experimental::Tensor& x, paddle::experimental::Tensor* out) { @@ -62,7 +68,8 @@ void by_pass(const paddle::experimental::Tensor& x, op->CheckAttrs(); op->InferVarType(block); op->InferShape(*block); - out->set_impl(new_out.impl()); + set_output(new_out, out); + // out->set_impl(new_out.impl()); } } // namespace prim diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual/utils/utils.h index 2a77d0cffd..69d879e37b 100644 --- a/paddle/fluid/prim/api/manual/utils/utils.h +++ b/paddle/fluid/prim/api/manual/utils/utils.h @@ -36,6 +36,11 @@ paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x, template void by_pass(const paddle::experimental::Tensor& x, paddle::experimental::Tensor* out); + +template +void set_output(const paddle::experimental::Tensor& x_tmp, + paddle::experimental::Tensor* x); + // These method don't need to be specified static phi::DDim get_reduce_dims(const phi::DDim& x_dims, const phi::DDim& y_dims) { -- GitLab