diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index ea9b272878c517311b5b0480e3740e6223a97815..440fc1f7e3776dd9ab00c855e22536918949534c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -546,6 +546,25 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::KernelSignature GetExpectedPtenKernelArgs(
+      const framework::ExecutionContext& ctx) const override {
+    if (Type() == "reduce_sum") {
+      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        return framework::KernelSignature(
+            "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
+            {"Out"});
+      }
+    }
+    if (Type() == "reduce_mean") {
+      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        return framework::KernelSignature(
+            "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+      }
+    }
+    // TODO(chentianyu03): support other cases after selected rows added
+    return framework::KernelSignature("reduce.unregistered", {}, {}, {});
+  }
 };
 
 class ReduceOpUseInputPlace : public ReduceOp {
diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h
index 7d92019d29e5383e5619d748ddc3c64ce9fa1e66..7b60ff12cf2ec1f729870d18afa6e81564166ec4 100644
--- a/paddle/pten/api/include/kernel_signature.h
+++ b/paddle/pten/api/include/kernel_signature.h
@@ -71,8 +71,6 @@ using mean_kernel = void (*)(const DeviceContext&,
                              const std::vector<int64_t>&,
                              bool,
                              bool,
-                             DataType,
-                             DataType,
                              DenseTensor*);
 
 using multiply_kernel = void (*)(const DeviceContext&,
@@ -99,7 +97,6 @@ using sum_kernel = void (*)(const DeviceContext&,
                             bool,
                             bool,
                             DataType,
-                            DataType,
                             DenseTensor*);
 
 using subtract_kernel = void (*)(const DeviceContext&,
diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h
index 3872f663fed7800be161dfceaebf319532b4ab54..0dfa17234e3bfb24f1b7dfd56c6d0b44890ac0be 100644
--- a/paddle/pten/include/math.h
+++ b/paddle/pten/include/math.h
@@ -45,9 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx,
           dev_ctx.GetPlace()),
       std::move(out_meta));
   bool reduce_all = false;
-  DataType out_dtype = pten::DataType::UNDEFINED;
-  Mean<T>(
-      dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
+  Mean<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
   return dense_out;
 }
 
@@ -57,7 +55,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
                 const std::vector<int64_t>& axis,
                 DataType dtype,
                 bool keep_dim) {
-  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
+  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype);
   pten::DenseTensor dense_out(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           dev_ctx.GetPlace()),
       std::move(out_meta));
@@ -67,12 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
   // so use default value(false) is OK.
   bool reduce_all = false;
 
-  if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 ||
-      x.dtype() == pten::DataType::INT64) {
-    dtype = pten::DataType::INT64;
-  }
-
-  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
+  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
   return dense_out;
 }
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index 4092e2842b9752cba95ab7b82b74733171a1db06..49d4a24e3a2c465cb4cdee2a2c729133b2f9ea4f 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -234,7 +234,8 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
 
 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
-                                bool keep_dim) {
+                                bool keep_dim,
+                                DataType dtype) {
   bool reduce_all = true;
   std::set<int64_t> dims_set(axis.begin(), axis.end());
   for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
@@ -268,10 +269,16 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
   }
   DDim out_dim = paddle::framework::make_ddim(out_dim_vector);
 
-  DataType out_dtype = x_meta.dtype;
-  if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
-      x_meta.dtype == DataType::INT64) {
-    out_dtype = DataType::INT64;
+  DataType out_dtype;
+  if (dtype != DataType::UNDEFINED) {
+    out_dtype = dtype;
+  } else {
+    if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
+        x_meta.dtype == DataType::INT64) {
+      out_dtype = DataType::INT64;
+    } else {
+      out_dtype = x_meta.dtype;
+    }
   }
 
   DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout);
diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h
index 408a77234f4b62d08933c9ac7ad33ca9e95814e4..3f28b2b48530ffe3d44098b204617fe4690939e9 100644
--- a/paddle/pten/infermeta/unary.h
+++ b/paddle/pten/infermeta/unary.h
@@ -56,5 +56,6 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
 
 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
-                                bool keep_dim);
+                                bool keep_dim,
+                                DataType dtype = DataType::UNDEFINED);
 }  // namespace pten
diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc
index 659a4d0e09686c7d05d372ecd9e8aaf5e5c13b03..861ecf2829feb5e8757610c887c634f1e70064ea 100644
--- a/paddle/pten/kernels/cpu/math.cc
+++ b/paddle/pten/kernels/cpu/math.cc
@@ -39,9 +39,8 @@ void Mean(const CPUContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out) {
+  auto out_dtype = x.dtype();
   pten::general::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -76,7 +75,6 @@ void Sum(const CPUContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out) {
   pten::general::Reduce(
diff --git a/paddle/pten/kernels/cpu/math.h b/paddle/pten/kernels/cpu/math.h
index 67a2feb4eef837b6aa89819e4e9d3b5c669d928c..61e361d37ab3d8947049f5a38fcff9873c2453d8 100644
--- a/paddle/pten/kernels/cpu/math.h
+++ b/paddle/pten/kernels/cpu/math.h
@@ -30,8 +30,6 @@ void Mean(const CPUContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out);
 
 template <typename T>
@@ -67,7 +65,6 @@ void Sum(const CPUContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out);
 
diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu
index 27d1ba1e043fe88f8edd0afd6e81f99ef41cffed..3dacc01e8b923b1334c3b848e1c2a5409574fd9d 100644
--- a/paddle/pten/kernels/cuda/math.cu
+++ b/paddle/pten/kernels/cuda/math.cu
@@ -68,9 +68,8 @@ void Mean(const CUDAContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out) {
+  auto out_dtype = x.dtype();
   pten::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -90,7 +89,6 @@ void Sum(const CUDAContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out) {
   pten::Reduce(
diff --git a/paddle/pten/kernels/cuda/math.h b/paddle/pten/kernels/cuda/math.h
index b2bca5e41185e09f2621a549b2f5da47e25f44f8..9cb379bcf7fadf30aa80d1f9113bb01fcf3947fb 100644
--- a/paddle/pten/kernels/cuda/math.h
+++ b/paddle/pten/kernels/cuda/math.h
@@ -32,8 +32,6 @@ void Mean(const CUDAContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out);
 
 template <typename T>
@@ -70,7 +68,6 @@ void Sum(const CUDAContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out);
 
diff --git a/paddle/pten/tests/api/test_sum_api.cc b/paddle/pten/tests/api/test_sum_api.cc
index d1b7ea33e8b76d49bbe8901bf4cb69ee9a57f9be..ff1609d3d4051b92a713aad5464a9cac1a446452 100644
--- a/paddle/pten/tests/api/test_sum_api.cc
+++ b/paddle/pten/tests/api/test_sum_api.cc
@@ -50,7 +50,7 @@ TEST(API, sum) {
   std::vector<int64_t> axis = {0, 1};
 
   // 2. test API
-  auto out = paddle::experimental::sum(x, axis, false);
+  auto out = paddle::experimental::sum(x, axis, DataType::UNDEFINED, false);
 
   // 3. check result
   ASSERT_EQ(out.dims().size(), 1);
   ASSERT_EQ(out.dims()[0], 1);
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index e0ea80feebeba59ab271f17871c4c144e7904c83..5a4ebb0179ca8164cefd7b8d442742d9e68a3e37 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -79,13 +79,14 @@
     func : matmul
 
 - api : mean
-  args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
+  args : (const Tensor& x, const std::vector<int64_t>& axis={}, bool keep_dim=false)
   output : Tensor
   infer_meta :
     func : ReduceInferMeta
+    param: [x, axis, keep_dim]
   kernel :
     func : mean
-    param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED]
+    param : [x, axis, keep_dim, false]
 
 - api : multiply
   args : (const Tensor& x, const Tensor& y)
@@ -130,13 +131,15 @@
     param : [x, y, -1]
 
 - api : sum
-  args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
+  args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
   output : Tensor
   infer_meta :
     func : ReduceInferMeta
+    param: [x, axis, keep_dim, dtype]
   kernel :
     func : sum
-    param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED]
+    param : [x, axis, keep_dim, false, DataType::UNDEFINED]
+    data_type : x
 
 - api : zeros_like
   args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
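Note (not part of the patch): the api.yaml and test changes above imply the call patterns sketched below for the updated C++ APIs. This is a minimal usage sketch assuming the generated paddle::experimental::sum/mean declarations match api.yaml; the reduce_examples helper, the DataType alias, and the include path are illustrative assumptions rather than code from this diff.

// Usage sketch only; assumes the generated experimental API matches api.yaml above.
#include "paddle/pten/api/include/api.h"  // assumed include for the generated API

using paddle::experimental::Tensor;
using DataType = paddle::experimental::DataType;

void reduce_examples(const Tensor& x) {
  // sum: dtype now comes before keep_dim; passing DataType::UNDEFINED keeps the
  // promotion rule in ReduceInferMeta (bool/int32/int64 inputs reduce into int64).
  auto s = paddle::experimental::sum(x, {0, 1}, DataType::UNDEFINED, /*keep_dim=*/false);

  // mean: the in_dtype/out_dtype attributes are gone; the kernel now derives
  // the output dtype from x.dtype() internally.
  auto m = paddle::experimental::mean(x, {0, 1}, /*keep_dim=*/false);
}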