提交 d636d72e 编写于 作者: W wangruting

modify name

上级 ea2e2495
......@@ -1839,6 +1839,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
False if self.composite_func_info == {} else True
)
if is_composite_grad_api:
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
if next_node_generator is not None:
has_higher_order_node = True
return (
......@@ -1850,9 +1857,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
next_node_generator.backward_forward_inputs_map,
)
# TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered
# TODO(Ruting): modify next_grad_node_creation_str when Flags_prim_enable deleted in the future
# if is_composite_grad_api:
# next_grad_node_creation_str = ''
elif not is_invoke_forward_api and not is_composite_grad_api:
next_grad_node_creation_str = f""" if(trace_backward) {{
PADDLE_THROW(phi::errors::Unavailable(
......@@ -2275,7 +2279,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Prepare for Node Creation if Necessary
outputs_autograd_meta_str = ""
compute_require_next_grad_str = ""
if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
if (
len(next_grad_node_creation_str) > 0
or is_invoke_forward_api
or inplace_for_grad_outs_str != ''
):
compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
# 3. Get Output AutoGradMeta
......
......@@ -31,7 +31,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
set_output<T>(grad_x_tmp, grad_x);
}
template <typename T>
void subtract_grad(const Tensor& x,
......@@ -48,7 +48,8 @@ void subtract_grad(const Tensor& x,
auto dy_reduce_res =
sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
// dy->set_impl(dy_tmp.impl());
} else {
by_pass<T>(scale_out_grad, dy);
}
......@@ -60,7 +61,8 @@ void subtract_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
// dx->set_impl(dx_tmp.impl());
} else {
by_pass<T>(out_grad, dx);
}
......@@ -81,7 +83,8 @@ void add_grad(const Tensor& x,
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
// dy->set_impl(dy_tmp.impl());
} else {
by_pass<T>(out_grad, dy);
}
......@@ -93,7 +96,8 @@ void add_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
// dx->set_impl(dx_tmp.impl());
} else {
by_pass<T>(out_grad, dx);
}
......@@ -134,8 +138,8 @@ void sum_grad(const Tensor& x,
} else {
x_grad_tmp = expand<T>(out_grad, x_dim);
}
x_grad->set_impl(x_grad_tmp.impl());
set_output<T>(x_grad_tmp, x_grad);
// x_grad->set_impl(x_grad_tmp.impl());
}
template <typename T>
......@@ -158,9 +162,11 @@ void divide_grad(const Tensor& x,
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
set_output<T>(dy_tmp, dy);
// dy->set_impl(dy_tmp.impl());
} else {
dy->set_impl(dy_res.impl());
set_output<T>(dy_res, dy);
// dy->set_impl(dy_res.impl());
}
} // indicate we will compute dy
if (dx) {
......@@ -174,9 +180,11 @@ void divide_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
set_output<T>(dx_tmp, dx);
// dx->set_impl(dx_tmp.impl());
} else {
dx->set_impl(dx_res.impl());
set_output<T>(dx_res, dx);
// dx->set_impl(dx_res.impl());
}
} // indicate we will compute dx
}
......@@ -187,7 +195,8 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp);
x_grad->set_impl(x_grad_tmp.impl());
set_output<T>(x_grad_tmp, x_grad);
// x_grad->set_impl(x_grad_tmp.impl());
}
}
} // namespace prim
......
......@@ -38,9 +38,19 @@ Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
}
return empty_like_ad_func(x, dtype, place);
}
template <>
void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x) {
x->set_impl(x_tmp.impl());
x->set_autograd_meta(x_tmp.mutable_autograd_meta());
}
template <>
void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
out->set_impl(x.impl());
set_output<Tensor>(x, out);
// out->set_impl(x.impl());
}
} // namespace prim
} // namespace paddle
......@@ -47,6 +47,12 @@ Tensor empty_like<DescTensor>(const Tensor& x,
paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
}
template <>
void set_output<DescTensor>(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x) {
x->set_impl(x_tmp.impl());
}
template <>
void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out) {
......@@ -62,7 +68,8 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
out->set_impl(new_out.impl());
set_output<DescTensor>(new_out, out);
// out->set_impl(new_out.impl());
}
} // namespace prim
......
......@@ -36,6 +36,11 @@ paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out);
template <typename T>
void set_output(const paddle::experimental::Tensor& x_tmp,
paddle::experimental::Tensor* x);
// These method don't need to be specified
static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
const phi::DDim& y_dims) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册