Unverified commit 67ffb86e authored by chentianyu03, committed by GitHub

[Phi]Modify reduce arg order (#40706)

* modify out and out_grad order in reduce_grad_kernel

* delete unused BoolReduceKernel

* fix conflict
Parent 9793fc5a
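Why the order matters: `out` and `out_grad` are both `const DenseTensor&`, so a call site with the two arguments swapped still compiles and fails only at runtime. A minimal standalone sketch of that hazard (hypothetical names and `Tensor` type, not Paddle code):

#include <iostream>
#include <string>

// Hypothetical stand-in for phi::DenseTensor; the real Paddle types differ.
struct Tensor {
  std::string name;
};

// After this commit the convention is (x, out, out_grad): the forward
// output comes before its gradient, matching the other phi grad kernels.
void ReduceGradSketch(const Tensor& x,
                      const Tensor& out,
                      const Tensor& out_grad) {
  std::cout << "x=" << x.name << " out=" << out.name
            << " out_grad=" << out_grad.name << "\n";
}

int main() {
  Tensor x{"X"}, out{"Out"}, out_grad{"Out@GRAD"};
  // Both orders compile because the parameter types are identical, so the
  // kernel declarations, every call site, and the argument-mapping tables
  // must all change together, which is exactly what this diff does.
  ReduceGradSketch(x, out, out_grad);     // correct under the new order
  // ReduceGradSketch(x, out_grad, out);  // would compile but swap meanings
  return 0;
}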
@@ -265,67 +265,6 @@ class ReduceKernel : public framework::OpKernel<T> {
         framework::TransToPhiDataType(cast_out_dtype), output);
   }
 };
-template <typename DeviceContext, typename OutT, typename Functor>
-class BoolReduceKernel : public framework::OpKernel<OutT> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    output->mutable_data<OutT>(context.GetPlace());
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-    // The dims has full dim, set the reduce_all is True
-    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
-    std::set<int> dims_set(dims.begin(), dims.end());
-    bool full_dim = true;
-    for (auto i = 0; i < input_dim_size; i++) {
-      if (dims_set.find(i) == dims_set.end()) {
-        full_dim = false;
-        break;
-      }
-    }
-    reduce_all = (reduce_all || full_dim);
-    if (reduce_all) {
-      // Flatten and reduce 1-D tensor
-      auto x = EigenVector<OutT>::Flatten(*input);
-      auto out = EigenScalar<OutT>::From(*output);
-      auto& place =
-          *context.template device_context<DeviceContext>().eigen_device();
-      auto reduce_dim = Eigen::array<int, 1>({{0}});
-      Functor functor;
-      functor(place, &x, &out, reduce_dim);
-    } else {
-      int ndim = input->dims().size();
-      int rdim = dims.size();
-      // comments for accelerating compiling temporarily.
-      if (ndim > 6) {
-        HandleLargeDim<DeviceContext, OutT, Functor>(context, input, output,
-                                                     dims, keep_dim);
-      } else {
-        HANDLE_DIM(6, 5);
-        HANDLE_DIM(6, 4);
-        HANDLE_DIM(6, 3);
-        HANDLE_DIM(6, 2);
-        HANDLE_DIM(6, 1);
-        HANDLE_DIM(5, 4);
-        HANDLE_DIM(5, 3);
-        HANDLE_DIM(5, 2);
-        HANDLE_DIM(5, 1);
-        HANDLE_DIM(4, 3);
-        HANDLE_DIM(4, 2);
-        HANDLE_DIM(4, 1);
-        HANDLE_DIM(3, 2);
-        HANDLE_DIM(3, 1);
-        HANDLE_DIM(2, 1);
-        HANDLE_DIM(1, 1);
-      }
-    }
-  }
-};
 template <typename DeviceContext, typename T, typename Functor>
 void LaunchReduceGradKernel(const framework::ExecutionContext& context,
......
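The deleted BoolReduceKernel duplicated logic that survives elsewhere (the commit message calls it unused). Its reduce_all promotion is still worth keeping in mind when reading the remaining kernels; here is a self-contained sketch of just that check, with a plain vector standing in for the tensor's dims:

#include <set>
#include <vector>

// Standalone rendering of the full-dim check from the deleted kernel: when
// `dims` names every axis of an `ndim`-rank input, the reduction is promoted
// to reduce_all and the tensor is flattened and reduced as 1-D.
bool CoversAllAxes(const std::vector<int>& dims, int ndim) {
  std::set<int> dims_set(dims.begin(), dims.end());
  for (int i = 0; i < ndim; ++i) {
    if (dims_set.find(i) == dims_set.end()) return false;
  }
  return true;
}

int main() {
  bool reduce_all = false;
  // Reducing a rank-3 tensor over axes {0, 1, 2} covers every axis.
  reduce_all = reduce_all || CoversAllAxes({0, 1, 2}, 3);
  return reduce_all ? 0 : 1;  // exits 0: promoted to reduce_all
}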
@@ -99,8 +99,8 @@ void ReduceSumGradKernel(const Context& dev_ctx,
   ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
                                                             x,
-                                                            out_grad,
                                                             paddle::none,
+                                                            out_grad,
                                                             dims,
                                                             keep_dim,
                                                             reduce_all,
@@ -121,8 +121,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           DenseTensor* x_grad) {
   ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
                                                              x,
-                                                             out_grad,
                                                              paddle::none,
+                                                             out_grad,
                                                              dims,
                                                              keep_dim,
                                                              reduce_all,
......
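In the two hunks above, the sum and mean gradient kernels pass `paddle::none` for the (now second) optional `out` parameter: their functors compute dx from out_grad alone and never read the forward output. A hedged sketch of that calling pattern, with `std::optional` and a placeholder `Tensor` substituting for the Paddle types:

#include <cstdint>
#include <optional>
#include <vector>

struct Tensor {};  // placeholder for phi::DenseTensor

// Sketch of the generic kernel's new order (x, out, out_grad), with
// std::optional standing in for paddle::optional. Sum and mean gradients
// need only out_grad, so those callers pass an empty optional for the
// forward output, as paddle::none does in the diff.
void ReduceGradSketch(const Tensor& x,
                      const std::optional<Tensor>& out,
                      const Tensor& out_grad,
                      const std::vector<int64_t>& dims,
                      bool keep_dim,
                      bool reduce_all,
                      Tensor* x_grad) {
  // ... broadcast out_grad back to x's shape, scaling in the mean case ...
}

int main() {
  Tensor x, out_grad, x_grad;
  // sum_grad/mean_grad: forward output not needed.
  ReduceGradSketch(x, std::nullopt, out_grad, {0}, false, false, &x_grad);
  return 0;
}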
@@ -33,7 +33,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              DataType out_dtype,
                              DenseTensor* dx) {
   ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
-      ctx, x, dout, out, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
+      ctx, x, out, dout, axis, keep_dim, reduce_all, in_dtype, out_dtype, dx);
 }
 }  // namespace phi
@@ -87,8 +87,8 @@ template <typename Context,
           bool kNoNeedBufferY = false>
 void ReduceGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
-                      const DenseTensor& out_grad,
                       const paddle::optional<DenseTensor>& out,
+                      const DenseTensor& out_grad,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       bool reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
 template <typename T, typename Context>
 void ReduceMaxGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& out_grad,
                          const DenseTensor& out,
+                         const DenseTensor& out_grad,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad) {
   ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
                                                            x,
-                                                           out_grad,
                                                            out,
+                                                           out_grad,
                                                            dims,
                                                            keep_dim,
                                                            reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
 template <typename T, typename Context>
 void ReduceMinGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& out_grad,
                          const DenseTensor& out,
+                         const DenseTensor& out_grad,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad) {
   ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(dev_ctx,
                                                            x,
-                                                           out_grad,
                                                            out,
+                                                           out_grad,
                                                            dims,
                                                            keep_dim,
                                                            reduce_all,
......
@@ -24,8 +24,8 @@ namespace phi {
 template <typename T, typename Context>
 void ReduceProdGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
-                          const DenseTensor& out_grad,
                           const DenseTensor& out,
+                          const DenseTensor& out_grad,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
@@ -34,8 +34,8 @@ void ReduceProdGradKernel(const Context& dev_ctx,
                           DenseTensor* x_grad) {
   ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(dev_ctx,
                                                        x,
-                                                       out_grad,
                                                        out,
+                                                       out_grad,
                                                        dims,
                                                        keep_dim,
                                                        reduce_all,
......
@@ -43,8 +43,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ReduceProdGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
-                          const DenseTensor& out_grad,
                           const DenseTensor& out,
+                          const DenseTensor& out_grad,
                           const std::vector<int64_t>& dims,
                           bool keep_dim,
                           bool reduce_all,
@@ -55,8 +55,8 @@ void ReduceProdGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ReduceMaxGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& out_grad,
                          const DenseTensor& out,
+                         const DenseTensor& out_grad,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
@@ -67,8 +67,8 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void ReduceMinGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& out_grad,
                          const DenseTensor& out,
+                         const DenseTensor& out_grad,
                          const std::vector<int64_t>& dims,
                          bool keep_dim,
                          bool reduce_all,
......
@@ -149,7 +149,7 @@ KernelSignature ReduceMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "max_grad",
-      {"X", GradVarName("Out"), "Out"},
+      {"X", "Out", GradVarName("Out")},
       {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
       {GradVarName("X")});
 }
@@ -158,7 +158,7 @@ KernelSignature ReduceMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "min_grad",
-      {"X", GradVarName("Out"), "Out"},
+      {"X", "Out", GradVarName("Out")},
       {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
       {GradVarName("X")});
 }
@@ -167,7 +167,7 @@ KernelSignature ReduceProdGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
       "prod_grad",
-      {"X", GradVarName("Out"), "Out"},
+      {"X", "Out", GradVarName("Out")},
       {"dim", "keep_dim", "reduce_all", "in_dtype", "out_dtype"},
       {GradVarName("X")});
 }
......
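The mapping hunks above are the other half of the reorder: KernelSignature binds fluid variable names to phi kernel parameters positionally, so once the kernels take `out` before `out_grad`, "Out" must precede GradVarName("Out") in each input list. A toy sketch of that positional contract (hypothetical `Signature` struct, not the real Paddle type):

#include <string>
#include <vector>

// Toy stand-in for phi::KernelSignature: the input-name list is matched to
// the phi kernel's tensor parameters by position, not by name, so swapping
// (x, out_grad, out) to (x, out, out_grad) in the kernel forces the same
// swap here.
struct Signature {
  std::string kernel;
  std::vector<std::string> inputs;  // maps, by position, to kernel params
};

int main() {
  // Mirrors the updated mapping: "Out" now precedes "Out@GRAD" to line up
  // with the (x, out, out_grad) parameter order of max/min/prod grad.
  Signature max_grad{"max_grad", {"X", "Out", "Out@GRAD"}};
  return max_grad.inputs[1] == "Out" ? 0 : 1;
}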