Unverified commit 22b5241f, authored by xiaoguoguo626807 and committed by GitHub

【prim】Modify dygraph code_gen, add set_output (#49918)

* modify name

* merge develop

* fix param

* fix exp gen bug

* fix sum_grad

* comment
Parent f71f77e9
@@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
False if self.composite_func_info == {} else True
)
if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
if next_node_generator is not None:
has_higher_order_node = True
return (
@@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
next_node_generator.backward_forward_inputs_map,
)
# TODO(Ruting): Integrate invoke and composite as composite so the rest branch can be covered
# TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
# if is_composite_grad_api:
# next_grad_node_creation_str = ''
elif not is_invoke_forward_api and not is_composite_grad_api:
next_grad_node_creation_str = f""" if(trace_backward) {{
PADDLE_THROW(phi::errors::Unavailable(
@@ -1978,6 +1982,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list
backward_inplace_map = self.backward_inplace_map
indent = GetIndent(1)
need_gen_trace_backard_for_inplace = False
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
@@ -2207,6 +2212,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
}} else {{
{inplace_str}
}}"""
need_gen_trace_backard_for_inplace = True
else:
inplace_for_grad_outs_str += inplace_str
@@ -2275,7 +2281,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Prepare for Node Creation if Necessary
outputs_autograd_meta_str = ""
compute_require_next_grad_str = ""
if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
if (
len(next_grad_node_creation_str) > 0
or is_invoke_forward_api
or need_gen_trace_backard_for_inplace
):
compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
# 3. Get Output AutoGradMeta
......
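The hunk above changes the eager code generator so that, for composite (prim) grad APIs, the emitted next-grad-node creation is wrapped in a `!paddle::prim::PrimCommonUtils::IsPrimEnabled()` check. Below is a compilable sketch of roughly what the generated grad-function body looks like after this change; `demo::PrimCommonUtils`, `CreateNextGradNode`, and `GeneratedGradFunction` are stand-ins for the generated symbols, not real Paddle APIs:

```cpp
#include <iostream>

// Stubs standing in for paddle::prim::PrimCommonUtils::IsPrimEnabled() and the
// generated "create next grad node" block; names and structure are illustrative.
namespace demo {

struct PrimCommonUtils {
  static bool IsPrimEnabled() { return true; }  // pretend composite (prim) mode is on
};

void CreateNextGradNode() {
  std::cout << "building grad node for the next (higher-order) backward pass\n";
}

// Rough shape of the generated grad function body after this change: the
// next-grad-node creation emitted by the generator only runs when prim mode
// is disabled.
void GeneratedGradFunction() {
  // ... call the composite / original backward kernel here ...
  if (!PrimCommonUtils::IsPrimEnabled()) {
    CreateNextGradNode();
  }
}

}  // namespace demo

int main() {
  demo::GeneratedGradFunction();  // prints nothing: prim mode skips node creation
  return 0;
}
```

With prim mode enabled, the generated node-creation block is simply skipped for these APIs.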
@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
set_output<T>(grad_x_tmp, grad_x);
}
template <typename T>
@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x,
auto dy_reduce_res = sum<T>(
scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
}
} else {
by_pass<T>(scale_out_grad, dy);
@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
}
} else {
by_pass<T>(out_grad, dx);
@@ -94,7 +94,7 @@ void add_grad(const Tensor& x,
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
}
} else {
@@ -111,7 +111,7 @@ void add_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
}
} else {
by_pass<T>(out_grad, dx);
@@ -139,22 +139,26 @@ void sum_grad(const Tensor& x,
reduce_all = false;
}
auto x_grad_tmp = Tensor();
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
if (x_dim_size == 1) {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
} else {
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
axis_ = axis.GetData();
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}
x_grad->set_impl(x_grad_tmp.impl());
set_output<T>(x_grad_tmp, x_grad);
}
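The sum_grad rewrite above adds a dedicated branch for rank-1 inputs, which bypasses the unsqueeze step entirely. The following is a standalone sketch of the branch structure the patched function now follows; shapes are plain `std::vector<int64_t>` values and the tensor ops are reduced to log lines, so it only illustrates which path runs for a given rank / keepdim / reduce_all combination (the function name `sum_grad_branch` is invented for the example):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Mirrors the control flow of the patched sum_grad; expand<T>/unsqueeze<T>
// from the real code are replaced by prints.
void sum_grad_branch(const Shape& x_dim, const Shape& axis,
                     bool keepdim, bool reduce_all) {
  const int64_t x_dim_size = static_cast<int64_t>(x_dim.size());
  if (x_dim_size == 1) {
    // New case: a rank-1 input never needs unsqueeze; expand out_grad
    // straight back to x's shape.
    std::cout << "expand out_grad to x_dim directly (rank-1 fast path)\n";
  } else if (!keepdim) {
    Shape axis_;
    if (reduce_all) {
      for (int64_t i = 1; i < x_dim_size; ++i) axis_.push_back(i);
    } else {
      axis_ = axis;
    }
    std::cout << "unsqueeze out_grad along " << axis_.size()
              << " axes, then expand to x_dim\n";
  } else {
    // keepdim == true: out_grad already has x's rank.
    std::cout << "expand out_grad to x_dim directly\n";
  }
}

int main() {
  sum_grad_branch({8}, {}, /*keepdim=*/false, /*reduce_all=*/true);       // rank-1 fast path
  sum_grad_branch({4, 5}, {1}, /*keepdim=*/false, /*reduce_all=*/false);  // unsqueeze + expand
  return 0;
}
```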
template <typename T>
@@ -175,15 +179,15 @@ void divide_grad(const Tensor& x,
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
dy->set_impl(dy_res.impl());
set_output<T>(dy_res, dy);
} else {
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
}
} else {
dy->set_impl(dy_res.impl());
set_output<T>(dy_res, dy);
}
} // indicate we will compute dy
if (dx) {
@@ -195,16 +199,16 @@ void divide_grad(const Tensor& x,
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
dx->set_impl(dx_res.impl());
set_output<T>(dx_res, dx);
} else {
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
}
} else {
dx->set_impl(dx_res.impl());
set_output<T>(dx_res, dx);
}
} // indicate we will compute dx
}
@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp);
x_grad->set_impl(x_grad_tmp.impl());
set_output<T>(x_grad_tmp, x_grad);
}
}
@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x,
if (x.dims() != y.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims());
if (!axes.size()) {
x_grad->set_impl(x_grad_unreduce.impl());
set_output<T>(x_grad_unreduce, x_grad);
} else {
auto x_grad_reduced = sum<T>(x_grad_unreduce,
phi::vectorize(axes),
@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x,
if (x_grad_reduced.dims().size() != x.dims().size()) {
x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
}
x_grad->set_impl(x_grad_reduced.impl());
set_output<T>(x_grad_reduced, x_grad);
}
} else {
x_grad->set_impl(x_grad_unreduce.impl());
set_output<T>(x_grad_unreduce, x_grad);
}
}
if (y_grad) {
@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x,
if (y.dims() != x.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims());
if (!axes.size()) {
y_grad->set_impl(y_grad_unreduce.impl());
set_output<T>(y_grad_unreduce, y_grad);
} else {
auto y_grad_reduced = sum<T>(y_grad_unreduce,
phi::vectorize(axes),
@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x,
if (y_grad_reduced.dims().size() != y.dims().size()) {
y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
}
y_grad->set_impl(y_grad_reduced.impl());
set_output<T>(y_grad_reduced, y_grad);
}
} else {
y_grad->set_impl(y_grad_unreduce.impl());
set_output<T>(y_grad_unreduce, y_grad);
}
}
}
@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x,
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
x_grad->set_impl(reduced.impl());
set_output<T>(reduced, x_grad);
}
} else {
by_pass<T>(out_grad, x_grad);
@@ -295,7 +299,7 @@ void expand_grad(const Tensor& x,
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
x_grad->set_impl(multiply<T>(out_grad, out).impl());
set_output<T>(multiply<T>(out_grad, out), x_grad);
}
}
......
@@ -38,9 +38,18 @@ Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
}
return empty_like_ad_func(x, dtype, place);
}
template <>
void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x) {
x->set_impl(x_tmp.impl());
x->set_autograd_meta(x_tmp.mutable_autograd_meta());
}
template <>
void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
out->set_impl(x.impl());
set_output<Tensor>(x, out);
}
} // namespace prim
} // namespace paddle
@@ -47,6 +47,12 @@ Tensor empty_like<DescTensor>(const Tensor& x,
paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
}
template <>
void set_output<DescTensor>(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x) {
x->set_impl(x_tmp.impl());
}
template <>
void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out) {
@@ -62,7 +68,7 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
out->set_impl(new_out.impl());
set_output<DescTensor>(new_out, out);
}
} // namespace prim
......
@@ -38,6 +38,10 @@ template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out);
template <typename T>
void set_output(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x);
// These method don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
const phi::DDim& in_dims) {
......
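The new `set_output` template is declared once in the shared header (above) and specialized per execution mode: the eager specialization (`set_output<Tensor>`) copies both the tensor impl and the autograd meta, while the static specialization (`set_output<DescTensor>`) only copies the impl, presumably so that outputs written by the composite rules stay wired into the eager autograd graph. Here is a toy, self-contained illustration of that specialization pattern; `ToyTensor`, `EagerBackend`, and `StaticBackend` are invented types, not Paddle's `Tensor`/`DescTensor`:

```cpp
#include <iostream>
#include <memory>
#include <string>

// Invented stand-ins; EagerBackend/StaticBackend play the role of the
// Tensor / DescTensor template arguments used in the real specializations.
struct ToyTensor {
  std::shared_ptr<std::string> impl;           // stands in for the TensorBase impl
  std::shared_ptr<std::string> autograd_meta;  // stands in for AutogradMeta
};

template <typename Backend>
void set_output(const ToyTensor& src, ToyTensor* dst);

struct EagerBackend {};
struct StaticBackend {};

// Eager mode: forward the data impl *and* the autograd meta, mirroring what
// set_output<Tensor> does above.
template <>
void set_output<EagerBackend>(const ToyTensor& src, ToyTensor* dst) {
  dst->impl = src.impl;
  dst->autograd_meta = src.autograd_meta;
}

// Static (program-desc) mode: only the impl is shared, mirroring
// set_output<DescTensor> above.
template <>
void set_output<StaticBackend>(const ToyTensor& src, ToyTensor* dst) {
  dst->impl = src.impl;
}

int main() {
  ToyTensor src{std::make_shared<std::string>("data"),
                std::make_shared<std::string>("meta")};
  ToyTensor eager_out, static_out;
  set_output<EagerBackend>(src, &eager_out);
  set_output<StaticBackend>(src, &static_out);
  std::cout << std::boolalpha << (eager_out.autograd_meta != nullptr) << ' '
            << (static_out.autograd_meta != nullptr) << '\n';  // prints: true false
  return 0;
}
```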