Commit d636d72e authored by W wangruting

modify name

Parent ea2e2495
@@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             False if self.composite_func_info == {} else True
         )
+        if is_composite_grad_api:
+            next_grad_node_creation_str = f"""
+  if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+    {next_grad_node_creation_str}
+  }}
+"""
         if next_node_generator is not None:
             has_higher_order_node = True
             return (
@@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 next_node_generator.backward_forward_inputs_map,
             )
         # TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered
-        # TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
-        # if is_composite_grad_api:
-        #     next_grad_node_creation_str = ''
         elif not is_invoke_forward_api and not is_composite_grad_api:
             next_grad_node_creation_str = f""" if(trace_backward) {{
    PADDLE_THROW(phi::errors::Unavailable(
@@ -2275,7 +2279,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # Prepare for Node Creation if Necessary
         outputs_autograd_meta_str = ""
         compute_require_next_grad_str = ""
-        if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
+        if (
+            len(next_grad_node_creation_str) > 0
+            or is_invoke_forward_api
+            or inplace_for_grad_outs_str != ''
+        ):
             compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
         # 3. Get Output AutoGradMeta
...
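Net effect of the first two hunks: when a backward API has a composite (prim) rule, the generated next-order grad-node creation code is now wrapped in a runtime guard instead of being stubbed out through the old TODO comments. A minimal sketch of the emitted C++ (the guard condition is taken verbatim from the f-string above; the body is op-specific and elided here):

```cpp
// Illustrative sketch only -- shape of the generated code, not a verbatim dump.
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
  // ... contents of next_grad_node_creation_str: create the next-order
  //     GradNode and wire up its autograd metadata (depends on the op) ...
}
```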
@@ -31,7 +31,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   auto tmp = pow<T>(out, 2.0);
   tmp = scale<T>(tmp, -1.0, 1.0, true);
   auto grad_x_tmp = multiply<T>(grad_out, tmp);
-  grad_x->set_impl(grad_x_tmp.impl());
+  set_output<T>(grad_x_tmp, grad_x);
 }
 template <typename T>
 void subtract_grad(const Tensor& x,
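For reference, the rewritten tanh_grad body is the standard tanh derivative: pow<T>(out, 2.0) builds out², scale<T>(tmp, -1.0, 1.0, true) turns it into 1 - out², and the multiply with grad_out applies the chain rule:

```latex
\frac{\partial L}{\partial x}
  = \frac{\partial L}{\partial \mathrm{out}} \cdot \bigl(1 - \mathrm{out}^{2}\bigr),
\qquad \mathrm{out} = \tanh(x),\quad \frac{d}{dx}\tanh(x) = 1 - \tanh^{2}(x).
```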
@@ -48,7 +48,8 @@ void subtract_grad(const Tensor& x,
       auto dy_reduce_res =
           sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
       auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      dy->set_impl(dy_tmp.impl());
+      set_output<T>(dy_tmp, dy);
+      // dy->set_impl(dy_tmp.impl());
     } else {
       by_pass<T>(scale_out_grad, dy);
     }
@@ -60,7 +61,8 @@ void subtract_grad(const Tensor& x,
       auto dx_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
       auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      dx->set_impl(dx_tmp.impl());
+      set_output<T>(dx_tmp, dx);
+      // dx->set_impl(dx_tmp.impl());
     } else {
       by_pass<T>(out_grad, dx);
     }
@@ -81,7 +83,8 @@ void add_grad(const Tensor& x,
       auto dy_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
       auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      dy->set_impl(dy_tmp.impl());
+      set_output<T>(dy_tmp, dy);
+      // dy->set_impl(dy_tmp.impl());
     } else {
       by_pass<T>(out_grad, dy);
     }
@@ -93,7 +96,8 @@ void add_grad(const Tensor& x,
       auto dx_reduce_res =
           sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
       auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      dx->set_impl(dx_tmp.impl());
+      set_output<T>(dx_tmp, dx);
+      // dx->set_impl(dx_tmp.impl());
     } else {
       by_pass<T>(out_grad, dx);
     }
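Both subtract_grad and add_grad follow the same broadcast-reduction pattern that the new set_output<T> call slots into: when the operand was broadcast up to the output shape, the upstream gradient is summed over the broadcast axes (reduce_dim) and reshaped back to the operand's shape; otherwise by_pass<T> forwards it unchanged. For add_grad the upstream gradient is reduced directly; for subtract_grad the same reduction is applied to the negated gradient (scale_out_grad). Schematically, for the dy branch:

```latex
\frac{\partial L}{\partial y}
  = \mathrm{reshape}\!\Bigl(\sum_{\text{broadcast axes of } y}
      \frac{\partial L}{\partial \mathrm{out}},\; \mathrm{shape}(y)\Bigr)
```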
@@ -134,8 +138,8 @@ void sum_grad(const Tensor& x,
   } else {
     x_grad_tmp = expand<T>(out_grad, x_dim);
   }
-  x_grad->set_impl(x_grad_tmp.impl());
+  set_output<T>(x_grad_tmp, x_grad);
+  // x_grad->set_impl(x_grad_tmp.impl());
 }
 template <typename T>
@@ -158,9 +162,11 @@ void divide_grad(const Tensor& x,
       auto dy_reduce_res =
           sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
       auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      dy->set_impl(dy_tmp.impl());
+      set_output<T>(dy_tmp, dy);
+      // dy->set_impl(dy_tmp.impl());
     } else {
-      dy->set_impl(dy_res.impl());
+      set_output<T>(dy_res, dy);
+      // dy->set_impl(dy_res.impl());
     }
   }  // indicate we will compute dy
   if (dx) {
@@ -174,9 +180,11 @@ void divide_grad(const Tensor& x,
       auto dx_reduce_res =
           sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
       auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      dx->set_impl(dx_tmp.impl());
+      set_output<T>(dx_tmp, dx);
+      // dx->set_impl(dx_tmp.impl());
     } else {
-      dx->set_impl(dx_res.impl());
+      set_output<T>(dx_res, dx);
+      // dx->set_impl(dx_res.impl());
     }
   }  // indicate we will compute dx
 }
@@ -187,7 +195,8 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
     auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
     auto tmp = divide<T>(div_x, out);
     auto x_grad_tmp = multiply<T>(out_grad, tmp);
-    x_grad->set_impl(x_grad_tmp.impl());
+    set_output<T>(x_grad_tmp, x_grad);
+    // x_grad->set_impl(x_grad_tmp.impl());
   }
 }
 }  // namespace prim
...
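Likewise, the sqrt_grad rule above encodes the usual square-root derivative, now written through set_output<T>: with out = sqrt(x), full<T>(..., 0.5) and divide<T>(div_x, out) build 0.5 / out, and the multiply with out_grad gives

```latex
\frac{\partial L}{\partial x}
  = \frac{\partial L}{\partial \mathrm{out}} \cdot \frac{0.5}{\mathrm{out}},
\qquad \text{since } \frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}} = \frac{0.5}{\mathrm{out}}.
```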
@@ -38,9 +38,19 @@ Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
   }
   return empty_like_ad_func(x, dtype, place);
 }
+template <>
+void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
+                        paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+  x->set_autograd_meta(x_tmp.mutable_autograd_meta());
+}
 template <>
 void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
-  out->set_impl(x.impl());
+  set_output<Tensor>(x, out);
+  // out->set_impl(x.impl());
 }
 }  // namespace prim
 }  // namespace paddle
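A minimal usage sketch of the new eager specialization (the variable names are hypothetical stand-ins; set_output, set_impl, set_autograd_meta and mutable_autograd_meta are the calls from the diff). The point of the rename-plus-extension is that binding a composite result to the caller's output now carries the autograd metadata along, which plain set_impl did not:

```cpp
// Sketch only, assuming an eager-mode composite backward rule.
auto grad_x_tmp = multiply<Tensor>(grad_out, tmp);  // some intermediate result
set_output<Tensor>(grad_x_tmp, grad_x);
// Roughly equivalent to:
//   grad_x->set_impl(grad_x_tmp.impl());                            // old behaviour
//   grad_x->set_autograd_meta(grad_x_tmp.mutable_autograd_meta());  // new: keep grad info
```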
@@ -47,6 +47,12 @@ Tensor empty_like<DescTensor>(const Tensor& x,
       paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
 }
+template <>
+void set_output<DescTensor>(const paddle::experimental::Tensor& x_tmp,
+                            paddle::experimental::Tensor* x) {
+  x->set_impl(x_tmp.impl());
+}
 template <>
 void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
                          paddle::experimental::Tensor* out) {
@@ -62,7 +68,8 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  out->set_impl(new_out.impl());
+  set_output<DescTensor>(new_out, out);
+  // out->set_impl(new_out.impl());
 }
 }  // namespace prim
...
@@ -36,6 +36,11 @@ paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
 template <typename T>
 void by_pass(const paddle::experimental::Tensor& x,
              paddle::experimental::Tensor* out);
+template <typename T>
+void set_output(const paddle::experimental::Tensor& x_tmp,
+                paddle::experimental::Tensor* x);
 // These method don't need to be specified
 static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
                                  const phi::DDim& y_dims) {
...
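Taken together, the header declaration plus the two specializations let a single templated composite rule serve both modes: instantiated with Tensor it runs eagerly and propagates autograd metadata, instantiated with DescTensor it only rebinds the impl while the static graph tracks gradients through the program description. A rough sketch (my_grad and its body are hypothetical; scale<T> and set_output<T> are the templates from this commit):

```cpp
// Sketch: one rule body, two instantiations.
template <typename T>
void my_grad(const paddle::experimental::Tensor& out_grad,
             paddle::experimental::Tensor* x_grad) {
  auto x_grad_tmp = scale<T>(out_grad, 2.0, 0.0, true);  // any composite math
  set_output<T>(x_grad_tmp, x_grad);  // Tensor: impl + autograd meta
                                      // DescTensor: impl only
}
// my_grad<Tensor>(...)      -> eager mode
// my_grad<DescTensor>(...)  -> static-graph (program desc) mode
```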