Unverified commit 22b5241f authored by xiaoguoguo626807, committed by GitHub

【prim】Modify dygraph code_gen, add set_output (#49918)

* modify name

* merge develop

* fix param

* fix exp gen bug

* fix sum_grad

* comment
Parent f71f77e9
@@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             False if self.composite_func_info == {} else True
         )
+        if is_composite_grad_api and next_grad_node_creation_str != '':
+            next_grad_node_creation_str = f"""
+  if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+    {next_grad_node_creation_str}
+  }}
+"""
         if next_node_generator is not None:
             has_higher_order_node = True
             return (
@@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 next_node_generator.backward_forward_inputs_map,
             )
         # TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered
-        # TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
-        # if is_composite_grad_api:
-        #     next_grad_node_creation_str = ''
         elif not is_invoke_forward_api and not is_composite_grad_api:
             next_grad_node_creation_str = f"""  if(trace_backward) {{
     PADDLE_THROW(phi::errors::Unavailable(
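For orientation, the branch added above only wraps whatever next-order grad-node-creation code the generator had already produced for the op. A minimal sketch of the emitted C++, with hypothetical placeholders for the node and its setters (the real body is the existing next_grad_node_creation_str):

    // Illustrative shape of the generated code; FooDoubleGradNode and its
    // setters are hypothetical stand-ins for whatever the generator emits.
    if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
      // original next-order grad node creation, e.g.:
      // auto grad_node = std::make_shared<FooDoubleGradNode>(1, 1);
      // grad_node->SetTensorWrapper...(...);
      // grad_node->SetAttribute...(...);
    }

The intent appears to be that, when the prim flag is on, the composite backward rules below build the higher-order graph themselves, so the handwritten double-grad node can be skipped.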
@@ -1978,6 +1982,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         backward_attrs_list = self.backward_attrs_list
         backward_inplace_map = self.backward_inplace_map
         indent = GetIndent(1)
+        need_gen_trace_backard_for_inplace = False
         # Construct grad_api function args
         # Order: TensorWrappers, GradTensors, Attributes
@@ -2207,6 +2212,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 }} else {{
   {inplace_str}
 }}"""
+                need_gen_trace_backard_for_inplace = True
             else:
                 inplace_for_grad_outs_str += inplace_str
@@ -2275,7 +2281,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # Prepare for Node Creation if Necessary
         outputs_autograd_meta_str = ""
         compute_require_next_grad_str = ""
-        if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
+        if (
+            len(next_grad_node_creation_str) > 0
+            or is_invoke_forward_api
+            or need_gen_trace_backard_for_inplace
+        ):
             compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"

         # 3. Get Output AutoGradMeta
......
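The new need_gen_trace_backard_for_inplace flag exists because the inplace branch above references trace_backward in the generated code, while trace_backward was previously emitted only when a next-order node or an invoke-style forward was present. A rough sketch of the generated shape, with the exact inplace condition abbreviated since it is internal to the generator:

    // Illustrative only: what the generated grad function now contains when
    // one of its grad outputs may be computed inplace.
    bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;
    if (trace_backward /* plus generator-internal checks */) {
      // skip the inplace optimization: the traced graph needs a fresh tensor
    } else {
      // inplace_str: reuse the existing buffer for the grad output
    }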
@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   auto tmp = pow<T>(out, 2.0);
   tmp = scale<T>(tmp, -1.0, 1.0, true);
   auto grad_x_tmp = multiply<T>(grad_out, tmp);
-  grad_x->set_impl(grad_x_tmp.impl());
+  set_output<T>(grad_x_tmp, grad_x);
 }

 template <typename T>
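The composite rule above is the textbook derivative written with prim ops:

    \frac{d\,\tanh(x)}{dx} = 1 - \tanh^2(x) = 1 - \text{out}^2

so scale<T>(pow<T>(out, 2.0), -1.0, 1.0, true) evaluates 1 - out^2 and grad_x = grad_out * (1 - out^2); the only change in this hunk is that the result is now written through set_output<T> instead of copying the impl directly.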
@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x,
         auto dy_reduce_res = sum<T>(
             scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
         auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-        dy->set_impl(dy_tmp.impl());
+        set_output<T>(dy_tmp, dy);
       }
     } else {
       by_pass<T>(scale_out_grad, dy);
@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x,
         auto dx_reduce_res =
             sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
         auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-        dx->set_impl(dx_tmp.impl());
+        set_output<T>(dx_tmp, dx);
       }
     } else {
       by_pass<T>(out_grad, dx);
@@ -94,7 +94,7 @@ void add_grad(const Tensor& x,
         auto dy_reduce_res =
             sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
         auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-        dy->set_impl(dy_tmp.impl());
+        set_output<T>(dy_tmp, dy);
       }
     } else {
@@ -111,7 +111,7 @@ void add_grad(const Tensor& x,
         auto dx_reduce_res =
             sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
         auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-        dx->set_impl(dx_tmp.impl());
+        set_output<T>(dx_tmp, dx);
       }
     } else {
       by_pass<T>(out_grad, dx);
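All of the subtract_grad and add_grad hunks above touch the same broadcasting pattern: when one operand was broadcast in the forward pass, its gradient is out_grad summed over the broadcast axes and reshaped back to that operand's shape, and that result is now handed to set_output<T>. A worked shape example, with hypothetical shapes chosen for illustration:

    // Assumed shapes (illustrative): x: [2, 3], y: [3]  =>  out, out_grad: [2, 3]
    // reduce_dim for dy identifies the broadcast axis {0}
    // sum<T>(out_grad, {0}, y.dtype(), /*keepdim=*/false)  -> shape [3]
    // reshape<T>(dy_reduce_res, phi::vectorize(y.dims()))  -> shape [3]
    // set_output<T>(dy_tmp, dy);  // dy now matches y's shape [3]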
@@ -139,22 +139,26 @@ void sum_grad(const Tensor& x,
     reduce_all = false;
   }
   auto x_grad_tmp = Tensor();
-  if (!keepdim) {
-    auto axis_ = std::vector<int64_t>();
-    if (reduce_all) {
-      for (int64_t i = 1; i < x_dim_size; i++) {
-        axis_.push_back(i);
-      }
-    } else {
-      axis_ = axis.GetData();
-    }
-    auto out_grad_ = unsqueeze<T>(out_grad, axis_);
-    x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
-  } else {
-    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
-  }
-  x_grad->set_impl(x_grad_tmp.impl());
+  if (x_dim_size == 1) {
+    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
+  } else {
+    if (!keepdim) {
+      auto axis_ = std::vector<int64_t>();
+      if (reduce_all) {
+        for (int64_t i = 1; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
+      }
+      auto out_grad_ = unsqueeze<T>(out_grad, axis_);
+      x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
+    } else {
+      x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
+    }
+  }
+  set_output<T>(x_grad_tmp, x_grad);
 }

 template <typename T>
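The restructuring above special-cases a 1-D input (presumably the "fix sum_grad" item in the commit message): when x_dim_size == 1 the reduced out_grad is expanded straight back to x's shape, while the original unsqueeze-then-expand path is kept for higher-rank inputs. Worked shapes, chosen for illustration:

    // keepdim == false, axis = {1}:
    //   x: [2, 3, 4]                           ->  out_grad: [2, 4]
    //   unsqueeze<T>(out_grad, {1})            ->  [2, 1, 4]
    //   expand<T>(out_grad_, IntArray(x_dim))  ->  [2, 3, 4]
    // new 1-D special case (x_dim_size == 1), full reduction:
    //   x: [5]                                 ->  out_grad: scalar
    //   expand<T>(out_grad, IntArray(x_dim))   ->  [5]   // no unsqueeze step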
@@ -175,15 +179,15 @@ void divide_grad(const Tensor& x,
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
       if (!reduce_dim.size()) {
-        dy->set_impl(dy_res.impl());
+        set_output<T>(dy_res, dy);
       } else {
         auto dy_reduce_res =
             sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
         auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-        dy->set_impl(dy_tmp.impl());
+        set_output<T>(dy_tmp, dy);
       }
     } else {
-      dy->set_impl(dy_res.impl());
+      set_output<T>(dy_res, dy);
     }
   } // indicate we will compute dy
   if (dx) {
@@ -195,16 +199,16 @@ void divide_grad(const Tensor& x,
       // Maybe need reduce here
       auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
       if (!reduce_dim.size()) {
-        dx->set_impl(dx_res.impl());
+        set_output<T>(dx_res, dx);
       } else {
         auto dx_reduce_res =
             sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
         auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-        dx->set_impl(dx_tmp.impl());
+        set_output<T>(dx_tmp, dx);
       }
     } else {
-      dx->set_impl(dx_res.impl());
+      set_output<T>(dx_res, dx);
     }
   } // indicate we will compute dx
 }
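The divide_grad changes are again only about how the result is written; the underlying math, computed in the elided context as dx_res and dy_res, is the usual quotient rule:

    \frac{\partial (x / y)}{\partial x} = \frac{1}{y}, \qquad \frac{\partial (x / y)}{\partial y} = -\frac{x}{y^2}

with the visible sum/reshape lines handling the broadcast case exactly as in add_grad and subtract_grad above.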
@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
     auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
     auto tmp = divide<T>(div_x, out);
     auto x_grad_tmp = multiply<T>(out_grad, tmp);
-    x_grad->set_impl(x_grad_tmp.impl());
+    set_output<T>(x_grad_tmp, x_grad);
   }
 }
@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x,
     if (x.dims() != y.dims()) {
       auto axes = get_reduce_dims(x.dims(), y.dims());
       if (!axes.size()) {
-        x_grad->set_impl(x_grad_unreduce.impl());
+        set_output<T>(x_grad_unreduce, x_grad);
       } else {
         auto x_grad_reduced = sum<T>(x_grad_unreduce,
                                      phi::vectorize(axes),
@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x,
         if (x_grad_reduced.dims().size() != x.dims().size()) {
           x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
         }
-        x_grad->set_impl(x_grad_reduced.impl());
+        set_output<T>(x_grad_reduced, x_grad);
       }
     } else {
-      x_grad->set_impl(x_grad_unreduce.impl());
+      set_output<T>(x_grad_unreduce, x_grad);
     }
   }
   if (y_grad) {
@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x,
     if (y.dims() != x.dims()) {
       auto axes = get_reduce_dims(y.dims(), x.dims());
       if (!axes.size()) {
-        y_grad->set_impl(y_grad_unreduce.impl());
+        set_output<T>(y_grad_unreduce, y_grad);
       } else {
         auto y_grad_reduced = sum<T>(y_grad_unreduce,
                                      phi::vectorize(axes),
@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x,
         if (y_grad_reduced.dims().size() != y.dims().size()) {
           y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
         }
-        y_grad->set_impl(y_grad_reduced.impl());
+        set_output<T>(y_grad_reduced, y_grad);
       }
     } else {
-      y_grad->set_impl(y_grad_unreduce.impl());
+      set_output<T>(y_grad_unreduce, y_grad);
     }
   }
 }
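For multiply_grad the unreduced gradients follow from the product rule:

    \frac{\partial (x \odot y)}{\partial x} = y, \qquad \frac{\partial (x \odot y)}{\partial y} = x

so x_grad_unreduce and y_grad_unreduce (computed in the elided lines) are elementwise products of out_grad with the other operand; the hunks above only swap the final writes to set_output<T> and keep the reduce/reshape handling for broadcast shapes.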
@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x,
       if (reduced.dims().size() != x.dims().size()) {
         reduced = reshape<T>(reduced, x.shape());
       }
-      x_grad->set_impl(reduced.impl());
+      set_output<T>(reduced, x_grad);
     }
   } else {
     by_pass<T>(out_grad, x_grad);
@@ -295,7 +299,7 @@ void expand_grad(const Tensor& x,
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    x_grad->set_impl(multiply<T>(out_grad, out).impl());
+    set_output<T>(multiply<T>(out_grad, out), x_grad);
   }
 }
......
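expand_grad reverses the forward broadcast by summing out_grad over the expanded axes (the reduced/reshape lines above) before writing with set_output<T>; exp_grad relies on the identity

    \frac{d\,e^x}{dx} = e^x = \text{out}

so its gradient is simply multiply<T>(out_grad, out), now also routed through set_output<T>.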
@@ -38,9 +38,18 @@ Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
   }
   return empty_like_ad_func(x, dtype, place);
 }

+template <>
+void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
+                        paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+  x->set_autograd_meta(x_tmp.mutable_autograd_meta());
+}
+
 template <>
 void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
-  out->set_impl(x.impl());
+  set_output<Tensor>(x, out);
 }
 } // namespace prim
 } // namespace paddle
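This specialization is the core of the change for eager mode: unlike the old out->set_impl(tmp.impl()), set_output<Tensor> also forwards the autograd meta, so a gradient produced by a composite rule presumably stays wired into the dynamic graph for higher-order backward. A minimal sketch of how a composite rule uses it (example_grad and its math are hypothetical, written against the prim ops seen elsewhere in this diff):

    // Hypothetical composite backward rule; T is Tensor in eager mode,
    // DescTensor in static mode.
    template <typename T>
    void example_grad(const Tensor& out_grad, Tensor* x_grad) {
      if (x_grad) {
        auto x_grad_tmp = scale<T>(out_grad, 2.0, 0.0, true);  // some composite math
        set_output<T>(x_grad_tmp, x_grad);  // replaces x_grad->set_impl(...)
      }
    }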
@@ -47,6 +47,12 @@ Tensor empty_like<DescTensor>(const Tensor& x,
       paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
 }

+template <>
+void set_output<DescTensor>(const paddle::experimental::Tensor& x_tmp,
+                            paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+}
+
 template <>
 void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
                          paddle::experimental::Tensor* out) {
@@ -62,7 +68,7 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  out->set_impl(new_out.impl());
+  set_output<DescTensor>(new_out, out);
 }
 } // namespace prim
......
@@ -38,6 +38,10 @@ template <typename T>
 void by_pass(const paddle::experimental::Tensor& x,
              paddle::experimental::Tensor* out);

+template <typename T>
+void set_output(const paddle::experimental::Tensor& x_tmp,
+                paddle::experimental::Tensor* x);
+
 // These method don't need to be specified
 static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                           const phi::DDim& in_dims) {
......
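The header only declares the template; each backend supplies its own definition, as the Tensor and DescTensor specializations above show. A hypothetical third backend would follow the same shape:

    // Hypothetical: MyBackendTensor is not a real Paddle type, it only
    // illustrates where a new specialization of set_output would live.
    template <>
    void set_output<MyBackendTensor>(const paddle::experimental::Tensor& x_tmp,
                                     paddle::experimental::Tensor* x) {
      x->set_impl(x_tmp.impl());
      // plus any backend-specific metadata propagation, mirroring the eager case
    }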