diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 414270a54456eebc8af4c37d18cc9f9132e8f298..17ea95e3f4babd6b8fc766244eb057c078b2cd87 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             False if self.composite_func_info == {} else True
         )
 
+        if is_composite_grad_api and next_grad_node_creation_str != '':
+            next_grad_node_creation_str = f"""
+  if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+    {next_grad_node_creation_str}
+  }}
+  """
+
         if next_node_generator is not None:
             has_higher_order_node = True
             return (
@@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 next_node_generator.backward_forward_inputs_map,
             )
         # TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered
-        # TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
-        # if is_composite_grad_api:
-        #     next_grad_node_creation_str = ''
         elif not is_invoke_forward_api and not is_composite_grad_api:
             next_grad_node_creation_str = f"""  if(trace_backward) {{
     PADDLE_THROW(phi::errors::Unavailable(
@@ -1978,6 +1982,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         backward_attrs_list = self.backward_attrs_list
         backward_inplace_map = self.backward_inplace_map
         indent = GetIndent(1)
+        need_gen_trace_backard_for_inplace = False
 
         # Construct grad_api function args
         # Order: TensorWrappers, GradTensors, Attributes
@@ -2207,6 +2212,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 }} else {{
   {inplace_str}
 }}"""
+                need_gen_trace_backard_for_inplace = True
             else:
                 inplace_for_grad_outs_str += inplace_str
 
@@ -2275,7 +2281,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # Prepare for Node Creation if Necessary
         outputs_autograd_meta_str = ""
         compute_require_next_grad_str = ""
-        if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
+        if (
+            len(next_grad_node_creation_str) > 0
+            or is_invoke_forward_api
+            or need_gen_trace_backard_for_inplace
+        ):
             compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
 
         # 3. Get Output AutoGradMeta
diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
index 6ddec5b4e292522f1a8ae4b326ba56a00b8e4c6f..9c12de9fe56607aeca3487c657a677ec0bf83da4 100644
--- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   auto tmp = pow<T>(out, 2.0);
   tmp = scale<T>(tmp, -1.0, 1.0, true);
   auto grad_x_tmp = multiply<T>(grad_out, tmp);
-  grad_x->set_impl(grad_x_tmp.impl());
+  set_output<T>(grad_x_tmp, grad_x);
 }
 
 template <typename T>
@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x,
       auto dy_reduce_res = sum<T>(
           scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
       auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      dy->set_impl(dy_tmp.impl());
+      set_output<T>(dy_tmp, dy);
     }
   } else {
     by_pass<T>(scale_out_grad, dy);
@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x,
       auto dx_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
       auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      dx->set_impl(dx_tmp.impl());
+      set_output<T>(dx_tmp, dx);
     }
   } else {
     by_pass<T>(out_grad, dx);
@@ -94,7 +94,7 @@ void add_grad(const Tensor& x,
       auto dy_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
       auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      dy->set_impl(dy_tmp.impl());
+      set_output<T>(dy_tmp, dy);
     }
 
   } else {
@@ -111,7 +111,7 @@ void add_grad(const Tensor& x,
       auto dx_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
       auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      dx->set_impl(dx_tmp.impl());
+      set_output<T>(dx_tmp, dx);
     }
   } else {
     by_pass<T>(out_grad, dx);
@@ -139,22 +139,26 @@ void sum_grad(const Tensor& x,
     reduce_all = false;
   }
   auto x_grad_tmp = Tensor();
-  if (!keepdim) {
-    auto axis_ = std::vector<int64_t>();
-    if (reduce_all) {
-      for (int64_t i = 1; i < x_dim_size; i++) {
-        axis_.push_back(i);
+  if (x_dim_size == 1) {
+    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
+  } else {
+    if (!keepdim) {
+      auto axis_ = std::vector<int64_t>();
+      if (reduce_all) {
+        for (int64_t i = 1; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
       }
+      auto out_grad_ = unsqueeze<T>(out_grad, axis_);
+      x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
     } else {
-      axis_ = axis.GetData();
+      x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
     }
-    auto out_grad_ = unsqueeze<T>(out_grad, axis_);
-    x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
-  } else {
-    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
   }
 
-  x_grad->set_impl(x_grad_tmp.impl());
+  set_output<T>(x_grad_tmp, x_grad);
 }
 
 template <typename T>
@@ -175,15 +179,15 @@ void divide_grad(const Tensor& x,
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
       if (!reduce_dim.size()) {
-        dy->set_impl(dy_res.impl());
+        set_output<T>(dy_res, dy);
       } else {
         auto dy_reduce_res =
             sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
         auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-        dy->set_impl(dy_tmp.impl());
+        set_output<T>(dy_tmp, dy);
       }
     } else {
-      dy->set_impl(dy_res.impl());
+      set_output<T>(dy_res, dy);
     }
   }  // indicate we will compute dy
   if (dx) {
@@ -195,16 +199,16 @@ void divide_grad(const Tensor& x,
       // Maybe need reduce here
       auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
       if (!reduce_dim.size()) {
-        dx->set_impl(dx_res.impl());
+        set_output<T>(dx_res, dx);
       } else {
         auto dx_reduce_res =
             sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
         auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-        dx->set_impl(dx_tmp.impl());
+        set_output<T>(dx_tmp, dx);
       }
     } else {
-      dx->set_impl(dx_res.impl());
+      set_output<T>(dx_res, dx);
     }
   }  // indicate we will compute dx
 }
 
@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
     auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
     auto tmp = divide<T>(div_x, out);
     auto x_grad_tmp = multiply<T>(out_grad, tmp);
-    x_grad->set_impl(x_grad_tmp.impl());
+    set_output<T>(x_grad_tmp, x_grad);
   }
 }
 
@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x,
     if (x.dims() != y.dims()) {
       auto axes = get_reduce_dims(x.dims(), y.dims());
       if (!axes.size()) {
-        x_grad->set_impl(x_grad_unreduce.impl());
+        set_output<T>(x_grad_unreduce, x_grad);
       } else {
         auto x_grad_reduced = sum<T>(x_grad_unreduce,
                                      phi::vectorize(axes),
@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x,
         if (x_grad_reduced.dims().size() != x.dims().size()) {
           x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
         }
-        x_grad->set_impl(x_grad_reduced.impl());
+        set_output<T>(x_grad_reduced, x_grad);
       }
     } else {
-      x_grad->set_impl(x_grad_unreduce.impl());
+      set_output<T>(x_grad_unreduce, x_grad);
     }
   }
   if (y_grad) {
@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x,
     if (y.dims() != x.dims()) {
      auto axes = get_reduce_dims(y.dims(), x.dims());
       if (!axes.size()) {
-        y_grad->set_impl(y_grad_unreduce.impl());
+        set_output<T>(y_grad_unreduce, y_grad);
       } else {
         auto y_grad_reduced = sum<T>(y_grad_unreduce,
                                      phi::vectorize(axes),
@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x,
         if (y_grad_reduced.dims().size() != y.dims().size()) {
           y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
         }
-        y_grad->set_impl(y_grad_reduced.impl());
+        set_output<T>(y_grad_reduced, y_grad);
       }
     } else {
-      y_grad->set_impl(y_grad_unreduce.impl());
+      set_output<T>(y_grad_unreduce, y_grad);
     }
   }
 }
@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x,
       if (reduced.dims().size() != x.dims().size()) {
         reduced = reshape<T>(reduced, x.shape());
       }
-      x_grad->set_impl(reduced.impl());
+      set_output<T>(reduced, x_grad);
     }
   } else {
     by_pass<T>(out_grad, x_grad);
@@ -295,7 +299,7 @@
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    x_grad->set_impl(multiply<T>(out_grad, out).impl());
+    set_output<T>(multiply<T>(out_grad, out), x_grad);
   }
 }
 
diff --git a/paddle/fluid/prim/api/manual/utils/eager_utils.cc b/paddle/fluid/prim/api/manual/utils/eager_utils.cc
index 96d0b4ea1f9be79325afd02579851c134830c87a..353945557f1d02386645a79c6b2d871fe90fb588 100644
--- a/paddle/fluid/prim/api/manual/utils/eager_utils.cc
+++ b/paddle/fluid/prim/api/manual/utils/eager_utils.cc
@@ -38,9 +38,18 @@ Tensor empty_like(const paddle::experimental::Tensor& x,
   }
   return empty_like_ad_func(x, dtype, place);
 }
+
+template <>
+void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
+                        paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+  x->set_autograd_meta(x_tmp.mutable_autograd_meta());
+}
+
 template <>
 void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
-  out->set_impl(x.impl());
+  set_output<Tensor>(x, out);
 }
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/prim/api/manual/utils/static_utils.cc b/paddle/fluid/prim/api/manual/utils/static_utils.cc
index c90cfdb34f014590bb22188e9b13c25ecaa690ba..74656cfe7d48d17fe0c3fc2122896ef10f8535b7 100644
--- a/paddle/fluid/prim/api/manual/utils/static_utils.cc
+++ b/paddle/fluid/prim/api/manual/utils/static_utils.cc
@@ -47,6 +47,12 @@ Tensor empty_like(const Tensor& x,
       paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
 }
 
+template <>
+void set_output<DescTensor>(const paddle::experimental::Tensor& x_tmp,
+                            paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+}
+
 template <>
 void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
                          paddle::experimental::Tensor* out) {
@@ -62,7 +68,7 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  out->set_impl(new_out.impl());
+  set_output<DescTensor>(new_out, out);
 }
 
 }  // namespace prim
diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual/utils/utils.h
index 22127d30d31690e1d145540dc3a19c476c54ae4f..20b02f2df9c79235d645f13a2a3cce8f8ff08d67 100644
--- a/paddle/fluid/prim/api/manual/utils/utils.h
+++ b/paddle/fluid/prim/api/manual/utils/utils.h
@@ -38,6 +38,10 @@ template <typename T>
 void by_pass(const paddle::experimental::Tensor& x,
              paddle::experimental::Tensor* out);
 
+template <typename T>
+void set_output(const paddle::experimental::Tensor& x_tmp,
+                paddle::experimental::Tensor* x);
+
 // These method don't need to be specified
 static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                           const phi::DDim& in_dims) {
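Note (illustrative, not part of the patch): the sketch below shows the call pattern that the composite backward rules above now share. All identifiers are taken from the files in this diff; the function name tanh_grad_sketch is invented for illustration, and the snippet assumes the Paddle prim headers are on the include path, so treat it as a reference for the pattern rather than standalone buildable code.

// Illustrative sketch: mirrors tanh_grad from composite_backward_api.h above.
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"

namespace paddle {
namespace prim {

template <typename T>
void tanh_grad_sketch(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  if (!grad_x) return;
  // grad_x = grad_out * (1 - out^2), built from the prim wrapper ops.
  auto tmp = scale<T>(pow<T>(out, 2.0), -1.0, 1.0, true);
  // Before this patch the rule ended with grad_x->set_impl(result.impl()).
  // set_output<T> now hands the result back instead; the eager specialization
  // (eager_utils.cc) forwards the autograd meta as well as the impl, while the
  // static specialization (static_utils.cc) forwards only the impl.
  set_output<T>(multiply<T>(grad_out, tmp), grad_x);
}

}  // namespace prim
}  // namespace paddle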