Unverified commit 0c024cb9, authored by chentianyu03, committed by GitHub

[Phi] Remove in_dtype, out_dtype in reduce grad (#40906)

* remove in_dtype, out_dtype in reduce grad

* set the dtype and layout in the ClearNoNeedBufferInputs func
Parent cadc4e6a
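The diffs below drop the in_dtype/out_dtype attributes from the phi reduce grad kernels; each kernel now decides whether to cast by comparing the dtypes of the tensors it already receives (x and out_grad). A minimal standalone sketch of that contract follows; the Tensor, DataType and reduce_sum_grad names here are illustrative assumptions, not phi APIs.

// Minimal sketch (assumed names, not phi code) of the dtype handling this
// commit moves into the kernels: compare x.dtype with out_grad.dtype instead
// of reading in_dtype/out_dtype attributes.
#include <cstdio>

enum class DataType { FLOAT32, FLOAT64 };

struct Tensor {
  DataType dtype;
};

Tensor reduce_sum_grad(const Tensor& x, const Tensor& out_grad) {
  // Work in out_grad's dtype first, mirroring the temporary x_grad_tmp below.
  Tensor x_grad{out_grad.dtype};
  if (out_grad.dtype != x.dtype) {
    // Cast the result back to x's dtype, as phi::CastKernel does in the diff.
    x_grad.dtype = x.dtype;
  }
  return x_grad;
}

int main() {
  Tensor x{DataType::FLOAT32};
  Tensor out_grad{DataType::FLOAT64};  // e.g. the forward sum promoted dtype
  Tensor x_grad = reduce_sum_grad(x, out_grad);
  std::printf("grad dtype follows x: %s\n",
              x_grad.dtype == x.dtype ? "yes" : "no");
  return 0;
}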
@@ -573,6 +573,8 @@ void ClearNoNeedBufferInputs(OpBase* op) {
         auto& old_tensor = var.Get<framework::LoDTensor>();
         new_tensor->Resize(old_tensor.dims());
         new_tensor->set_lod(old_tensor.lod());
+        new_tensor->set_type(old_tensor.dtype());
+        new_tensor->set_layout(old_tensor.layout());
         each_var.reset(new_var);
       }
     }
......
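The hunk above pairs with the kernel change: because the grad kernels now read dtype (and layout) from the forward input X rather than from the removed attributes, the placeholder that replaces a no-need-buffer input must keep that metadata even though its data buffer is dropped. A rough standalone sketch of the idea, using assumed names rather than the real imperative-layer classes:

// Illustrative sketch only: a buffer-free "shell" tensor that preserves the
// metadata (dims, dtype, layout) which grad kernels still need to consult.
#include <cstdint>
#include <memory>
#include <vector>

enum class DataType { FLOAT32, FLOAT64 };
enum class DataLayout { NCHW, NHWC };

struct ShellTensor {
  std::vector<int64_t> dims;
  DataType dtype = DataType::FLOAT32;
  DataLayout layout = DataLayout::NCHW;
  std::shared_ptr<void> data;  // left empty: the buffer is intentionally dropped
};

ShellTensor ClearBuffer(const ShellTensor& old_tensor) {
  ShellTensor shell;
  shell.dims = old_tensor.dims;      // analogous to Resize(old_tensor.dims())
  shell.dtype = old_tensor.dtype;    // analogous to set_type(old_tensor.dtype())
  shell.layout = old_tensor.layout;  // analogous to set_layout(old_tensor.layout())
  return shell;                      // shell.data stays null
}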
@@ -79,34 +79,25 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad) {
   if (dims.size() == 1) {
-    if (out_dtype != DataType::UNDEFINED) {
-      DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
+    if (out_grad.dtype() != x.dtype()) {
+      DenseTensorMeta x_grad_meta(
+          out_grad.dtype(), x_grad->dims(), x_grad->layout());
       DenseTensor x_grad_tmp =
           phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
 
       ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, &x_grad_tmp);
 
-      phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
+      phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
     } else {
       ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, x_grad);
     }
   }
 
-  ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
-                                                            x,
-                                                            paddle::none,
-                                                            out_grad,
-                                                            dims,
-                                                            keep_dim,
-                                                            reduce_all,
-                                                            in_dtype,
-                                                            out_dtype,
-                                                            x_grad);
+  ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(
+      dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 template <typename T, typename Context>
@@ -116,19 +107,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
-                          DataType in_dtype,
-                          DataType out_dtype,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
-                                                             x,
-                                                             paddle::none,
-                                                             out_grad,
-                                                             dims,
-                                                             keep_dim,
-                                                             reduce_all,
-                                                             in_dtype,
-                                                             out_dtype,
-                                                             x_grad);
+  ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(
+      dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 }  // namespace phi
......
@@ -27,7 +27,5 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              const std::vector<int64_t>& axis,
                              bool keep_dim,
                              bool reduce_all,
-                             DataType in_dtype,
-                             DataType out_dtype,
                              DenseTensor* dx);
 
 }  // namespace phi
@@ -52,14 +52,12 @@ void ReduceGradKernel(const Context& dev_ctx,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DataType in_dtype,
-                      DataType out_dtype,
                       DenseTensor* x_grad) {
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
-  auto pt_out_dtype = in_dtype;
+  auto pt_out_dtype = x.dtype();
 
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
@@ -76,17 +74,11 @@ void ReduceGradKernel(const Context& dev_ctx,
   DenseTensor new_d_out(d_out->dtype());
   new_d_out.ShareDataWith(*d_out);
   new_d_out.Resize(phi::make_ddim(update_dims));
-  if (in_dtype != DataType::UNDEFINED) {
-    dev_ctx.Alloc(d_x, in_dtype);
-  } else {
-    dev_ctx.Alloc(d_x, d_out->dtype());
-  }
+
+  dev_ctx.Alloc(d_x, x.dtype());
 
   auto pt_d_out = new_d_out;
   auto pt_d_x = *d_x;
-  if (in_dtype == DataType::UNDEFINED) {
-    pt_out_dtype = d_out->dtype();
-  }
   using MPType = typename kps::details::MPTypeTrait<T>::Type;
 
   phi::ReduceGrad<T, TransformOp<T, MPType>>(
......
@@ -31,18 +31,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::IdentityFunctor>(dev_ctx,
-                                                     x,
-                                                     out_grad,
-                                                     dims,
-                                                     keep_dim,
-                                                     reduce_all,
-                                                     in_dtype,
-                                                     out_dtype,
-                                                     x_grad);
+  ReduceGradKernel<T, Context, kps::IdentityFunctor>(
+      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 template <typename T, typename Context>
@@ -52,18 +43,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
-                          DataType in_dtype,
-                          DataType out_dtype,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<T, Context, kps::DivideFunctor>(dev_ctx,
-                                                   x,
-                                                   out_grad,
-                                                   dims,
-                                                   keep_dim,
-                                                   reduce_all,
-                                                   in_dtype,
-                                                   out_dtype,
-                                                   x_grad);
+  ReduceGradKernel<T, Context, kps::DivideFunctor>(
+      dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 }  // namespace phi
......
@@ -29,11 +29,9 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              const std::vector<int64_t>& axis,
                              bool keep_dim,
                              bool reduce_all,
-                             DataType in_dtype,
-                             DataType out_dtype,
                              DenseTensor* dx) {
   ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
-      ctx, x, out, dout, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
+      ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
 }
 
 }  // namespace phi
@@ -33,8 +33,6 @@ void ComputeFromInput(const Context& dev_ctx,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DataType in_dtype,
-                      DataType out_dtype,
                       DenseTensor* x_grad) {
   auto* input0 = &x;
   auto* input1 = out.get_ptr();
@@ -92,11 +90,10 @@ void ReduceGradKernel(const Context& dev_ctx,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
-                      DataType in_dtype,
-                      DataType out_dtype,
                       DenseTensor* x_grad) {
-  if (in_dtype != DataType::UNDEFINED) {
-    DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
+  if (x.dtype() != out_grad.dtype()) {
+    DenseTensorMeta x_grad_meta(
+        out_grad.dtype(), x_grad->dims(), x_grad->layout());
     DenseTensor x_grad_tmp =
         phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
     ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
@@ -108,11 +105,9 @@ void ReduceGradKernel(const Context& dev_ctx,
         dims,
         keep_dim,
         reduce_all,
-        in_dtype,
-        out_dtype,
         &x_grad_tmp);
 
-    phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
+    phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
   } else {
     ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
         dev_ctx,
@@ -123,8 +118,6 @@ void ReduceGradKernel(const Context& dev_ctx,
         dims,
         keep_dim,
         reduce_all,
-        in_dtype,
-        out_dtype,
         x_grad);
   }
 }
......
@@ -29,19 +29,9 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
-                                                           x,
-                                                           out,
-                                                           out_grad,
-                                                           dims,
-                                                           keep_dim,
-                                                           reduce_all,
-                                                           in_dtype,
-                                                           out_dtype,
-                                                           x_grad);
+  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
+      dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 }  // namespace phi
@@ -29,19 +29,9 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad) {
-  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
-                                                           x,
-                                                           out,
-                                                           out_grad,
-                                                           dims,
-                                                           keep_dim,
-                                                           reduce_all,
-                                                           in_dtype,
-                                                           out_dtype,
-                                                           x_grad);
+  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
+      dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 }  // namespace phi
@@ -29,19 +29,9 @@ void ReduceProdGradKernel(const Context& dev_ctx,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
-                          DataType in_dtype,
-                          DataType out_dtype,
                           DenseTensor* x_grad) {
-  ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(dev_ctx,
-                                                       x,
-                                                       out,
-                                                       out_grad,
-                                                       dims,
-                                                       keep_dim,
-                                                       reduce_all,
-                                                       in_dtype,
-                                                       out_dtype,
-                                                       x_grad);
+  ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
+      dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
 
 }  // namespace phi
@@ -25,8 +25,6 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad);
 
 template <typename T, typename Context>
@@ -36,8 +34,6 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
-                          DataType in_dtype,
-                          DataType out_dtype,
                           DenseTensor* x_grad);
 
 template <typename T, typename Context>
@@ -48,8 +44,6 @@ void ReduceProdGradKernel(const Context& dev_ctx,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
-                          DataType in_dtype,
-                          DataType out_dtype,
                           DenseTensor* x_grad);
 
 template <typename T, typename Context>
@@ -60,8 +54,6 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad);
 
 template <typename T, typename Context>
@@ -72,8 +64,6 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
-                         DataType in_dtype,
-                         DataType out_dtype,
                          DenseTensor* x_grad);
 
 }  // namespace phi
@@ -24,11 +24,10 @@ KernelSignature FrobeniusNormOpArgumentMapping(
 
 KernelSignature FrobeniusNormGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "frobenius_norm_grad",
-      {"X", "Out", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("frobenius_norm_grad",
+                         {"X", "Out", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 }  // namespace phi
......
@@ -129,47 +129,42 @@ KernelSignature ReduceAllOpArgumentMapping(const ArgumentMappingContext& ctx) {
 
 KernelSignature ReduceSumGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "sum_grad",
-      {"X", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("sum_grad",
+                         {"X", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 KernelSignature ReduceMeanGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "mean_grad",
-      {"X", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("mean_grad",
+                         {"X", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 KernelSignature ReduceMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "max_grad",
-      {"X", "Out", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("max_grad",
+                         {"X", "Out", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 KernelSignature ReduceMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "min_grad",
-      {"X", "Out", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("min_grad",
+                         {"X", "Out", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 KernelSignature ReduceProdGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "prod_grad",
-      {"X", "Out", GradVarName("Out")},
-      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
-      {GradVarName("X")});
+  return KernelSignature("prod_grad",
                         {"X", "Out", GradVarName("Out")},
+                         {"dim", "keep_dim", "reduce_all"},
+                         {GradVarName("X")});
 }
 
 }  // namespace phi
......