Unverified · Commit 0c024cb9 authored by chentianyu03, committed by GitHub

[Phi] Remove in_dtype, out_dtype in reduce grad (#40906)

* remove in_dtype, out_dtype in reduce grad

* set the dtype and layout in the ClearNoNeedBufferInputs func
Parent cadc4e6a
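
For context, the net effect of the change on the kernel headers: the two trailing DataType hints are dropped, and any needed cast is inferred inside the kernel from the tensors themselves. A sketch of the trimmed sum_grad declaration, taken from the reduce_grad_kernel.h hunk further below:

// Declaration after this commit; in_dtype/out_dtype are gone, the kernel
// decides whether to cast by comparing x.dtype() with out_grad.dtype().
template <typename T, typename Context>
void ReduceSumGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& out_grad,
                         const std::vector<int64_t>& dims,
                         bool keep_dim,
                         bool reduce_all,
                         DenseTensor* x_grad);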
......@@ -573,6 +573,8 @@ void ClearNoNeedBufferInputs(OpBase* op) {
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
new_tensor->set_type(old_tensor.dtype());
new_tensor->set_layout(old_tensor.layout());
each_var.reset(new_var);
}
}
......
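The hunk above is what makes the removal safe in dygraph: an input whose buffer is cleared as "no need buffer" now keeps its dtype and layout, so a grad kernel that later calls x.dtype() on the placeholder still sees the original type. A minimal sketch of the preserved metadata (the surrounding variable-handling code is the existing ClearNoNeedBufferInputs logic, abbreviated here):

// Copy every piece of metadata from the cleared tensor; only the data buffer
// is dropped. set_type/set_layout are the two lines added by this commit.
new_tensor->Resize(old_tensor.dims());        // shape
new_tensor->set_lod(old_tensor.lod());        // LoD
new_tensor->set_type(old_tensor.dtype());     // dtype, read later via x.dtype()
new_tensor->set_layout(old_tensor.layout());  // layout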
......@@ -79,34 +79,25 @@ void ReduceSumGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
if (dims.size() == 1) {
if (out_dtype != DataType::UNDEFINED) {
DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
if (out_grad.dtype() != x.dtype()) {
DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout());
DenseTensor x_grad_tmp =
phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, &x_grad_tmp);
phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
} else {
ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, x_grad);
}
}
ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
x,
paddle::none,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(
dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad);
}
template <typename T, typename Context>
......@@ -116,19 +107,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
x,
paddle::none,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(
dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad);
}
} // namespace phi
......
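In the CPU kernels the dtype bookkeeping moves from attributes to tensors: the single-dim fast path now compares out_grad.dtype() against x.dtype() instead of testing out_dtype != DataType::UNDEFINED. A condensed sketch of the resulting branch, same logic as the hunk above with the removed lines omitted:

if (out_grad.dtype() != x.dtype()) {
  // The incoming gradient carries a different dtype (e.g. the forward sum
  // accumulated in a wider type): reduce in that dtype, then cast the result
  // back to x's dtype for the returned gradient.
  DenseTensorMeta x_grad_meta(
      out_grad.dtype(), x_grad->dims(), x_grad->layout());
  DenseTensor x_grad_tmp =
      phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
  ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, &x_grad_tmp);
  phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
} else {
  ComputeFromInput<T, Context>(dev_ctx, x, out_grad, dims, x_grad);
}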
......@@ -27,7 +27,5 @@ void FrobeniusNormGradKernel(const Context& ctx,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* dx);
} // namespace phi
......@@ -52,14 +52,12 @@ void ReduceGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
auto* in_x = &x;
auto* d_out = &out_grad;
auto* d_x = x_grad;
auto pt_out_dtype = in_dtype;
auto pt_out_dtype = x.dtype();
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
......@@ -76,17 +74,11 @@ void ReduceGradKernel(const Context& dev_ctx,
DenseTensor new_d_out(d_out->dtype());
new_d_out.ShareDataWith(*d_out);
new_d_out.Resize(phi::make_ddim(update_dims));
if (in_dtype != DataType::UNDEFINED) {
dev_ctx.Alloc(d_x, in_dtype);
} else {
dev_ctx.Alloc(d_x, d_out->dtype());
}
dev_ctx.Alloc(d_x, x.dtype());
auto pt_d_out = new_d_out;
auto pt_d_x = *d_x;
if (in_dtype == DataType::UNDEFINED) {
pt_out_dtype = d_out->dtype();
}
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, TransformOp<T, MPType>>(
......
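The GPU path in reduce_grad.h is simplified the same way: the UNDEFINED fallback disappears, and both the result dtype and the allocation of d_x now come straight from x. Condensed from the hunk above:

// dtype resolution after the change; previously this started from in_dtype
// and fell back to d_out->dtype() when in_dtype == DataType::UNDEFINED.
auto pt_out_dtype = x.dtype();
// allocate the gradient in x's dtype unconditionally; previously this
// branched on in_dtype != DataType::UNDEFINED.
dev_ctx.Alloc(d_x, x.dtype());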
......@@ -31,18 +31,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<T, Context, kps::IdentityFunctor>(dev_ctx,
x,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<T, Context, kps::IdentityFunctor>(
dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
}
template <typename T, typename Context>
......@@ -52,18 +43,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<T, Context, kps::DivideFunctor>(dev_ctx,
x,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<T, Context, kps::DivideFunctor>(
dev_ctx, x, out_grad, dims, keep_dim, reduce_all, x_grad);
}
} // namespace phi
......
......@@ -29,11 +29,9 @@ void FrobeniusNormGradKernel(const Context& ctx,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* dx) {
ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
ctx, x, out, dout, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
}
} // namespace phi
......@@ -33,8 +33,6 @@ void ComputeFromInput(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
auto* input0 = &x;
auto* input1 = out.get_ptr();
......@@ -92,11 +90,10 @@ void ReduceGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
if (in_dtype != DataType::UNDEFINED) {
DenseTensorMeta x_grad_meta(out_dtype, x_grad->dims(), x_grad->layout());
if (x.dtype() != out_grad.dtype()) {
DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout());
DenseTensor x_grad_tmp =
phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
......@@ -108,11 +105,9 @@ void ReduceGradKernel(const Context& dev_ctx,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
&x_grad_tmp);
phi::CastKernel<T>(dev_ctx, x_grad_tmp, in_dtype, x_grad);
phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
} else {
ComputeFromInput<Context, T, Functor, kNoNeedBufferX, kNoNeedBufferY>(
dev_ctx,
......@@ -123,8 +118,6 @@ void ReduceGradKernel(const Context& dev_ctx,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
}
}
......
......@@ -29,19 +29,9 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
x,
out,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
} // namespace phi
......@@ -29,19 +29,9 @@ void ReduceMinGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
x,
out,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
} // namespace phi
......@@ -29,19 +29,9 @@ void ReduceProdGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad) {
ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(dev_ctx,
x,
out,
out_grad,
dims,
keep_dim,
reduce_all,
in_dtype,
out_dtype,
x_grad);
ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
} // namespace phi
......@@ -25,8 +25,6 @@ void ReduceSumGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad);
template <typename T, typename Context>
......@@ -36,8 +34,6 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad);
template <typename T, typename Context>
......@@ -48,8 +44,6 @@ void ReduceProdGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad);
template <typename T, typename Context>
......@@ -60,8 +54,6 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad);
template <typename T, typename Context>
......@@ -72,8 +64,6 @@ void ReduceMinGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* x_grad);
} // namespace phi
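
With the headers trimmed, a direct call site only passes the reduction attributes. A hypothetical invocation of the CPU kernel; the tensor and context names are illustrative, not from this commit:

// Assumes dev_ctx is a phi::CPUContext and x/out_grad are float DenseTensors.
phi::DenseTensor x_grad;
phi::ReduceSumGradKernel<float, phi::CPUContext>(dev_ctx,
                                                 x,
                                                 out_grad,
                                                 /*dims=*/{0},
                                                 /*keep_dim=*/false,
                                                 /*reduce_all=*/false,
                                                 &x_grad);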
......@@ -24,10 +24,9 @@ KernelSignature FrobeniusNormOpArgumentMapping(
KernelSignature FrobeniusNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"frobenius_norm_grad",
return KernelSignature("frobenius_norm_grad",
{"X", "Out", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
......
......@@ -129,46 +129,41 @@ KernelSignature ReduceAllOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ReduceSumGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sum_grad",
return KernelSignature("sum_grad",
{"X", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
KernelSignature ReduceMeanGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"mean_grad",
return KernelSignature("mean_grad",
{"X", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
KernelSignature ReduceMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"max_grad",
return KernelSignature("max_grad",
{"X", "Out", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
KernelSignature ReduceMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"min_grad",
return KernelSignature("min_grad",
{"X", "Out", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
KernelSignature ReduceProdGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"prod_grad",
return KernelSignature("prod_grad",
{"X", "Out", GradVarName("Out")},
{"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
}
......
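The argument mappings shrink in lockstep: the attribute list in each KernelSignature is expected to line up with the kernel's attribute parameters, so "in_dtype"/"out_dtype" disappear here as well. The resulting mapping for sum_grad, as registered above:

return KernelSignature("sum_grad",
                       {"X", GradVarName("Out")},          // tensor inputs
                       {"dim", "keep_dim", "reduce_all"},  // -> dims, keep_dim, reduce_all
                       {GradVarName("X")});                // tensor outputs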