Unverified commit 67ffb86e, authored by chentianyu03 and committed by GitHub

[Phi]Modify reduce arg order (#40706)

* modify out and out_grad order in reduce_grad_kernel

* delete unused BoolReduceKernel

* fix conflict
Parent commit: 9793fc5a
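The change is a pure argument-order update: every phi reduce grad kernel, the shared `ReduceGradKernel` template, and the op-to-kernel argument mappings in the compat layer now put the forward output `out` before the output gradient `out_grad`. Below is a minimal, self-contained sketch of the new ordering; the struct, the name `ReduceSumGradSketch`, and the raw pointer standing in for `paddle::optional` are illustrative stand-ins, not Paddle code.

```cpp
#include <vector>

// Toy stand-in for phi::DenseTensor; the real kernels also take a device
// context and dtype arguments that are omitted here.
struct DenseTensor {
  std::vector<float> data;
};

// New ordering after this commit: the forward output `out` (optional, so a
// pointer in this sketch) comes before the output gradient `out_grad`.
void ReduceSumGradSketch(const DenseTensor& x,
                         const DenseTensor* out,  // paddle::optional<DenseTensor> in the real code
                         const DenseTensor& out_grad,
                         const std::vector<int64_t>& dims,
                         bool keep_dim,
                         bool reduce_all,
                         DenseTensor* x_grad) {
  (void)out;       // sum_grad never reads the forward output
  (void)dims;
  (void)keep_dim;
  (void)reduce_all;
  // Sum reduced over all dims: every input element receives the scalar out_grad.
  x_grad->data.assign(x.data.size(), out_grad.data.at(0));
}

int main() {
  DenseTensor x{{1.f, 2.f, 3.f}}, out_grad{{1.f}}, x_grad;
  // The second tensor slot is now `out`; sum_grad passes "nothing" there,
  // mirroring the paddle::none argument in the hunks below.
  ReduceSumGradSketch(x, /*out=*/nullptr, out_grad, {0}, false, true, &x_grad);
  return 0;
}
```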
@@ -265,67 +265,6 @@ class ReduceKernel : public framework::OpKernel<T> {
        framework::TransToPhiDataType(cast_out_dtype), output);
  }
};
template <typename DeviceContext, typename OutT, typename Functor>
class BoolReduceKernel : public framework::OpKernel<OutT> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool reduce_all = context.Attr<bool>("reduce_all");
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");
    output->mutable_data<OutT>(context.GetPlace());
    auto dims = context.Attr<std::vector<int>>("dim");
    bool keep_dim = context.Attr<bool>("keep_dim");
    // The dims has full dim, set the reduce_all is True
    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
    std::set<int> dims_set(dims.begin(), dims.end());
    bool full_dim = true;
    for (auto i = 0; i < input_dim_size; i++) {
      if (dims_set.find(i) == dims_set.end()) {
        full_dim = false;
        break;
      }
    }
    reduce_all = (reduce_all || full_dim);
    if (reduce_all) {
      // Flatten and reduce 1-D tensor
      auto x = EigenVector<OutT>::Flatten(*input);
      auto out = EigenScalar<OutT>::From(*output);
      auto& place =
          *context.template device_context<DeviceContext>().eigen_device();
      auto reduce_dim = Eigen::array<int, 1>({{0}});
      Functor functor;
      functor(place, &x, &out, reduce_dim);
    } else {
      int ndim = input->dims().size();
      int rdim = dims.size();
      // comments for accelerating compiling temporarily.
      if (ndim > 6) {
        HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
            dims, keep_dim);
      } else {
        HANDLE_DIM(6, 5);
        HANDLE_DIM(6, 4);
        HANDLE_DIM(6, 3);
        HANDLE_DIM(6, 2);
        HANDLE_DIM(6, 1);
        HANDLE_DIM(5, 4);
        HANDLE_DIM(5, 3);
        HANDLE_DIM(5, 2);
        HANDLE_DIM(5, 1);
        HANDLE_DIM(4, 3);
        HANDLE_DIM(4, 2);
        HANDLE_DIM(4, 1);
        HANDLE_DIM(3, 2);
        HANDLE_DIM(3, 1);
        HANDLE_DIM(2, 1);
        HANDLE_DIM(1, 1);
      }
    }
  }
};
template <typename DeviceContext, typename T, typename Functor>
void LaunchReduceGradKernel(const framework::ExecutionContext& context,
......
@@ -99,8 +99,8 @@ void ReduceSumGradKernel(const Context& dev_ctx,
  ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
      x,
      out_grad,
      paddle::none,
      out_grad,
      dims,
      keep_dim,
      reduce_all,
@@ -121,8 +121,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
    DenseTensor* x_grad) {
  ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
      x,
      out_grad,
      paddle::none,
      out_grad,
      dims,
      keep_dim,
      reduce_all,
......
@@ -33,7 +33,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
    DataType out_dtype,
    DenseTensor* dx) {
  ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
      ctx, x, dout, out, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
      ctx, x, out, dout, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
}
}  // namespace phi
@@ -87,8 +87,8 @@ template <typename Context,
    bool kNoNeedBufferY = false>
void ReduceGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const paddle::optional<DenseTensor>& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
template <typename T, typename Context>
void ReduceMaxGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
    DenseTensor* x_grad) {
  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
      x,
      out_grad,
      out,
      out_grad,
      dims,
      keep_dim,
      reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
template <typename T, typename Context>
void ReduceMinGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceMinGradKernel(const Context& dev_ctx,
    DenseTensor* x_grad) {
  ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
      x,
      out_grad,
      out,
      out_grad,
      dims,
      keep_dim,
      reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
template <typename T, typename Context>
void ReduceProdGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceProdGradKernel(const Context& dev_ctx,
    DenseTensor* x_grad) {
  ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(dev_ctx,
      x,
      out_grad,
      out,
      out_grad,
      dims,
      keep_dim,
      reduce_all,
......
@@ -43,8 +43,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ReduceProdGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
@@ -55,8 +55,8 @@ void ReduceProdGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ReduceMaxGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
@@ -67,8 +67,8 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ReduceMinGradKernel(const Context& dev_ctx,
    const DenseTensor& x,
    const DenseTensor& out_grad,
    const DenseTensor& out,
    const DenseTensor& out_grad,
    const std::vector<int64_t>& dims,
    bool keep_dim,
    bool reduce_all,
......
@@ -149,7 +149,7 @@ KernelSignature ReduceMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "max_grad",
      {"X", GradVarName("Out"), "Out"},
      {"X", "Out", GradVarName("Out")},
      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
      {GradVarName("X")});
}
@@ -158,7 +158,7 @@ KernelSignature ReduceMinGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "min_grad",
      {"X", GradVarName("Out"), "Out"},
      {"X", "Out", GradVarName("Out")},
      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
      {GradVarName("X")});
}
@@ -167,7 +167,7 @@ KernelSignature ReduceProdGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "prod_grad",
      {"X", GradVarName("Out"), "Out"},
      {"X", "Out", GradVarName("Out")},
      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
      {GradVarName("X")});
}
......
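These compat mappings are positional: the names in each input list are bound to the kernel's tensor parameters in order, which is why they are reordered together with the kernels above. As a reading aid, here is the updated `max_grad` mapping again, annotated with the correspondence to `ReduceMaxGradKernel`'s new parameter order; the comments are editorial and not part of the source file.

```cpp
// Positional correspondence with ReduceMaxGradKernel(dev_ctx, x, out, out_grad, ...):
//   "X"                -> x         (input tensor)
//   "Out"              -> out       (forward output, now the 2nd tensor argument)
//   GradVarName("Out") -> out_grad  (output gradient, now the 3rd tensor argument)
KernelSignature ReduceMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "max_grad",
      {"X", "Out", GradVarName("Out")},
      {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
      {GradVarName("X")});
}
```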