未验证 提交 eaa2363e 编写于 作者: C chentianyu03 提交者: GitHub

[pten] modify reduce_sum reduce_mean args (#38216)

* modify sum mean args

* add GetExpectedPtenKernelArgs for reduce_op

* modify kernel args number

* modify kernel args number
上级 843435ff
...@@ -546,6 +546,25 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -546,6 +546,25 @@ class ReduceOp : public framework::OperatorWithKernel {
} }
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
// Maps this reduce operator onto the pten kernel signature it should run with.
//
// Only `reduce_sum` and `reduce_mean` whose input variable "X" holds a
// framework::LoDTensor are mapped (to the "sum" and "mean" kernels
// respectively). Every other combination yields the "reduce.unregistered"
// placeholder signature with empty input/attribute/output lists.
framework::KernelSignature GetExpectedPtenKernelArgs(
    const framework::ExecutionContext& ctx) const override {
  const auto& op_type = Type();
  const bool is_sum = (op_type == "reduce_sum");
  const bool is_mean = !is_sum && (op_type == "reduce_mean");

  // Input "X" is inspected only for the two supported operator types.
  if ((is_sum || is_mean) &&
      ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
    if (is_sum) {
      return framework::KernelSignature(
          "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
          {"Out"});
    }
    return framework::KernelSignature(
        "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
  }

  // TODO(chentianyu03): support other cases after selected rows added
  return framework::KernelSignature("reduce.unregistered", {}, {}, {});
}
}; };
class ReduceOpUseInputPlace : public ReduceOp { class ReduceOpUseInputPlace : public ReduceOp {
......
...@@ -71,8 +71,6 @@ using mean_kernel = void (*)(const DeviceContext&, ...@@ -71,8 +71,6 @@ using mean_kernel = void (*)(const DeviceContext&,
const std::vector<int64_t>&, const std::vector<int64_t>&,
bool, bool,
bool, bool,
DataType,
DataType,
DenseTensor*); DenseTensor*);
using multiply_kernel = void (*)(const DeviceContext&, using multiply_kernel = void (*)(const DeviceContext&,
...@@ -99,7 +97,6 @@ using sum_kernel = void (*)(const DeviceContext&, ...@@ -99,7 +97,6 @@ using sum_kernel = void (*)(const DeviceContext&,
bool, bool,
bool, bool,
DataType, DataType,
DataType,
DenseTensor*); DenseTensor*);
using subtract_kernel = void (*)(const DeviceContext&, using subtract_kernel = void (*)(const DeviceContext&,
......
...@@ -45,9 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx, ...@@ -45,9 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx,
dev_ctx.GetPlace()), dev_ctx.GetPlace()),
std::move(out_meta)); std::move(out_meta));
bool reduce_all = false; bool reduce_all = false;
DataType out_dtype = pten::DataType::UNDEFINED; Mean<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
Mean<T>(
dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
return dense_out; return dense_out;
} }
...@@ -57,7 +55,7 @@ DenseTensor Sum(const ContextT& dev_ctx, ...@@ -57,7 +55,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
DataType dtype, DataType dtype,
bool keep_dim) { bool keep_dim) {
auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype);
pten::DenseTensor dense_out( pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>( pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()), dev_ctx.GetPlace()),
...@@ -67,12 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx, ...@@ -67,12 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
// so use default value(false) is OK. // so use default value(false) is OK.
bool reduce_all = false; bool reduce_all = false;
if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 || Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
x.dtype() == pten::DataType::INT64) {
dtype = pten::DataType::INT64;
}
Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -234,7 +234,8 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, ...@@ -234,7 +234,8 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim) { bool keep_dim,
DataType dtype) {
bool reduce_all = true; bool reduce_all = true;
std::set<int64_t> dims_set(axis.begin(), axis.end()); std::set<int64_t> dims_set(axis.begin(), axis.end());
for (int64_t i = 0; i < x_meta.dims.size(); ++i) { for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
...@@ -268,10 +269,16 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, ...@@ -268,10 +269,16 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
} }
DDim out_dim = paddle::framework::make_ddim(out_dim_vector); DDim out_dim = paddle::framework::make_ddim(out_dim_vector);
DataType out_dtype = x_meta.dtype; DataType out_dtype;
if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 || if (dtype != DataType::UNDEFINED) {
x_meta.dtype == DataType::INT64) { out_dtype = dtype;
out_dtype = DataType::INT64; } else {
if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
x_meta.dtype == DataType::INT64) {
out_dtype = DataType::INT64;
} else {
out_dtype = x_meta.dtype;
}
} }
DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout); DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout);
......
...@@ -56,5 +56,6 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, ...@@ -56,5 +56,6 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim); bool keep_dim,
DataType dtype = DataType::UNDEFINED);
} // namespace pten } // namespace pten
...@@ -39,9 +39,8 @@ void Mean(const CPUContext& dev_ctx, ...@@ -39,9 +39,8 @@ void Mean(const CPUContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
auto out_dtype = x.dtype();
pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>( pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
} }
...@@ -76,7 +75,6 @@ void Sum(const CPUContext& dev_ctx, ...@@ -76,7 +75,6 @@ void Sum(const CPUContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>( pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
......
...@@ -30,8 +30,6 @@ void Mean(const CPUContext& dev_ctx, ...@@ -30,8 +30,6 @@ void Mean(const CPUContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out); DenseTensor* out);
template <typename T> template <typename T>
...@@ -67,7 +65,6 @@ void Sum(const CPUContext& dev_ctx, ...@@ -67,7 +65,6 @@ void Sum(const CPUContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype, DataType out_dtype,
DenseTensor* out); DenseTensor* out);
......
...@@ -68,9 +68,8 @@ void Mean(const CUDAContext& dev_ctx, ...@@ -68,9 +68,8 @@ void Mean(const CUDAContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
auto out_dtype = x.dtype();
pten::Reduce<T, paddle::operators::CustomMean>( pten::Reduce<T, paddle::operators::CustomMean>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
} }
...@@ -90,7 +89,6 @@ void Sum(const CUDAContext& dev_ctx, ...@@ -90,7 +89,6 @@ void Sum(const CUDAContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
pten::Reduce<T, paddle::operators::CustomSum>( pten::Reduce<T, paddle::operators::CustomSum>(
......
...@@ -32,8 +32,6 @@ void Mean(const CUDAContext& dev_ctx, ...@@ -32,8 +32,6 @@ void Mean(const CUDAContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype,
DenseTensor* out); DenseTensor* out);
template <typename T> template <typename T>
...@@ -70,7 +68,6 @@ void Sum(const CUDAContext& dev_ctx, ...@@ -70,7 +68,6 @@ void Sum(const CUDAContext& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType in_dtype,
DataType out_dtype, DataType out_dtype,
DenseTensor* out); DenseTensor* out);
......
...@@ -50,7 +50,7 @@ TEST(API, sum) { ...@@ -50,7 +50,7 @@ TEST(API, sum) {
std::vector<int64_t> axis = {0, 1}; std::vector<int64_t> axis = {0, 1};
// 2. test API // 2. test API
auto out = paddle::experimental::sum(x, axis, false); auto out = paddle::experimental::sum(x, axis, DataType::UNDEFINED, false);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 1); ASSERT_EQ(out.dims().size(), 1);
ASSERT_EQ(out.dims()[0], 1); ASSERT_EQ(out.dims()[0], 1);
......
...@@ -79,13 +79,14 @@ ...@@ -79,13 +79,14 @@
func : matmul func : matmul
- api : mean - api : mean
args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim) args : (const Tensor& x, const std::vector<int64_t>& axis={}, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ReduceInferMeta func : ReduceInferMeta
param: [x, axis, keep_dim]
kernel : kernel :
func : mean func : mean
param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED] param : [x, axis, keep_dim, false]
- api : multiply - api : multiply
args : (const Tensor& x, const Tensor& y) args : (const Tensor& x, const Tensor& y)
...@@ -130,13 +131,15 @@ ...@@ -130,13 +131,15 @@
param : [x, y, -1] param : [x, y, -1]
- api : sum - api : sum
args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim) args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ReduceInferMeta func : ReduceInferMeta
param: [x, axis, keep_dim, dtype]
kernel : kernel :
func : sum func : sum
param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED] param : [x, axis, keep_dim, false, DataType::UNDEFINED]
data_type : x
- api : zeros_like - api : zeros_like
args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册