Unverified commit eaa2363e, authored by chentianyu03, committed by GitHub

[pten] modify reduce_sum reduce_mean args (#38216)

* modify sum mean args

* add GetExpectedPtenKernelArgs for reduce_op

* modify kernel args number

* modify kernel args number
Parent 843435ff
@@ -546,6 +546,25 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::KernelSignature GetExpectedPtenKernelArgs(
+      const framework::ExecutionContext& ctx) const override {
+    if (Type() == "reduce_sum") {
+      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        return framework::KernelSignature(
+            "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
+            {"Out"});
+      }
+    }
+    if (Type() == "reduce_mean") {
+      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
+        return framework::KernelSignature(
+            "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+      }
+    }
+    // TODO(chentianyu03): support other cases after selected rows added
+    return framework::KernelSignature("reduce.unregistered", {}, {}, {});
+  }
 };

 class ReduceOpUseInputPlace : public ReduceOp {
......
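The hook added above is the bridge from the fluid ReduceOp to the new pten kernels: it names the target kernel and lists which op inputs, attributes, and outputs to forward. A hypothetical sketch of the mapping it establishes (`op` and `ctx` are assumed stand-ins, not code from this patch; the signature contents are copied from the hunk above):

// Illustrative only -- not the framework's actual dispatch code.
framework::KernelSignature sig = op.GetExpectedPtenKernelArgs(ctx);
// reduce_sum  -> {"sum",  {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, {"Out"}}
// reduce_mean -> {"mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}}
// Any other case (e.g. a SelectedRows input, per the TODO) resolves to the
// "reduce.unregistered" placeholder, which presumably leaves dispatch on the
// existing fluid kernel path.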
@@ -71,8 +71,6 @@ using mean_kernel = void (*)(const DeviceContext&,
                              const std::vector<int64_t>&,
                              bool,
                              bool,
-                             DataType,
-                             DataType,
                              DenseTensor*);

 using multiply_kernel = void (*)(const DeviceContext&,
@@ -99,7 +97,6 @@ using sum_kernel = void (*)(const DeviceContext&,
                             bool,
                             bool,
                             DataType,
-                            DataType,
                             DenseTensor*);

 using subtract_kernel = void (*)(const DeviceContext&,
......
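Spelled out, the two typedefs now differ in their dtype parameters: `mean` no longer takes any, while `sum` keeps a single output dtype. A reconstruction of the post-change signatures (the parameter-name comments and the `const DenseTensor&` input line are inferred from the kernel definitions further down, not visible in these hunks):

using mean_kernel = void (*)(const DeviceContext&,
                             const DenseTensor& /*x*/,
                             const std::vector<int64_t>& /*dims*/,
                             bool /*keep_dim*/,
                             bool /*reduce_all*/,
                             DenseTensor* /*out*/);

using sum_kernel = void (*)(const DeviceContext&,
                            const DenseTensor& /*x*/,
                            const std::vector<int64_t>& /*dims*/,
                            bool /*keep_dim*/,
                            bool /*reduce_all*/,
                            DataType /*out_dtype*/,
                            DenseTensor* /*out*/);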
@@ -45,9 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx,
           dev_ctx.GetPlace()),
       std::move(out_meta));
   bool reduce_all = false;
-  DataType out_dtype = pten::DataType::UNDEFINED;
-  Mean<T>(
-      dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
+  Mean<T>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
   return dense_out;
 }
@@ -57,7 +55,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
                 const std::vector<int64_t>& axis,
                 DataType dtype,
                 bool keep_dim) {
-  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
+  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype);
   pten::DenseTensor dense_out(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           dev_ctx.GetPlace()),
@@ -67,12 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
   // so use default value(false) is OK.
   bool reduce_all = false;
-  if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 ||
-      x.dtype() == pten::DataType::INT64) {
-    dtype = pten::DataType::INT64;
-  }
-  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
+  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
   return dense_out;
 }
......
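Dropping the ad-hoc integer branch does not change observable behavior: the promotion now lives in ReduceInferMeta (see the infermeta hunk below), and its result is threaded through as out_meta.dtype. A usage sketch under that assumption (`dev_ctx` and the INT32 tensor `x_i32` are assumed to exist):

auto out = pten::Sum<int>(dev_ctx, x_i32, /*axis=*/{0},
                          pten::DataType::UNDEFINED, /*keep_dim=*/false);
// out.dtype() == pten::DataType::INT64: with dtype left UNDEFINED,
// ReduceInferMeta widens BOOL/INT32/INT64 inputs to INT64 for the sum.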
@@ -234,7 +234,8 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,

 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
-                                bool keep_dim) {
+                                bool keep_dim,
+                                DataType dtype) {
   bool reduce_all = true;
   std::set<int64_t> dims_set(axis.begin(), axis.end());
   for (int64_t i = 0; i < x_meta.dims.size(); ++i) {
@@ -268,10 +269,16 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
   }
   DDim out_dim = paddle::framework::make_ddim(out_dim_vector);

-  DataType out_dtype = x_meta.dtype;
-  if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
-      x_meta.dtype == DataType::INT64) {
-    out_dtype = DataType::INT64;
+  DataType out_dtype;
+  if (dtype != DataType::UNDEFINED) {
+    out_dtype = dtype;
+  } else {
+    if (x_meta.dtype == DataType::BOOL || x_meta.dtype == DataType::INT32 ||
+        x_meta.dtype == DataType::INT64) {
+      out_dtype = DataType::INT64;
+    } else {
+      out_dtype = x_meta.dtype;
+    }
   }

   DenseTensorMeta return_meta(out_dtype, out_dim, x_meta.layout);
......
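The dtype rule the new branch implements can be restated as a small pure function (a clarifying paraphrase of the logic above, not code from this patch):

DataType ResolveReduceOutDtype(DataType in_dtype, DataType requested) {
  // An explicitly requested dtype always wins.
  if (requested != DataType::UNDEFINED) return requested;
  // Otherwise bool and 32/64-bit integer inputs accumulate in INT64,
  // matching the old hard-coded behavior.
  if (in_dtype == DataType::BOOL || in_dtype == DataType::INT32 ||
      in_dtype == DataType::INT64) {
    return DataType::INT64;
  }
  // All remaining dtypes pass through unchanged.
  return in_dtype;
}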
@@ -56,5 +56,6 @@ DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
-                                bool keep_dim);
+                                bool keep_dim,
+                                DataType dtype = DataType::UNDEFINED);

 }  // namespace pten
@@ -39,9 +39,8 @@ void Mean(const CPUContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out) {
+  auto out_dtype = x.dtype();
   pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -76,7 +75,6 @@ void Sum(const CPUContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out) {
   pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
......
@@ -30,8 +30,6 @@ void Mean(const CPUContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out);

 template <typename T>
@@ -67,7 +65,6 @@ void Sum(const CPUContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out);
......
@@ -68,9 +68,8 @@ void Mean(const CUDAContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out) {
+  auto out_dtype = x.dtype();
   pten::Reduce<T, paddle::operators::CustomMean>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -90,7 +89,6 @@ void Sum(const CUDAContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out) {
   pten::Reduce<T, paddle::operators::CustomSum>(
......
@@ -32,8 +32,6 @@ void Mean(const CUDAContext& dev_ctx,
           const std::vector<int64_t>& dims,
           bool keep_dim,
           bool reduce_all,
-          DataType in_dtype,
-          DataType out_dtype,
           DenseTensor* out);

 template <typename T>
@@ -70,7 +68,6 @@ void Sum(const CUDAContext& dev_ctx,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
-         DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out);
......
@@ -50,7 +50,7 @@ TEST(API, sum) {
   std::vector<int64_t> axis = {0, 1};

   // 2. test API
-  auto out = paddle::experimental::sum(x, axis, false);
+  auto out = paddle::experimental::sum(x, axis, DataType::UNDEFINED, false);

   // 3. check result
   ASSERT_EQ(out.dims().size(), 1);
   ASSERT_EQ(out.dims()[0], 1);
......
@@ -79,13 +79,14 @@
     func : matmul

 - api : mean
-  args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
+  args : (const Tensor& x, const std::vector<int64_t>& axis={}, bool keep_dim=false)
   output : Tensor
   infer_meta :
     func : ReduceInferMeta
+    param: [x, axis, keep_dim]
   kernel :
     func : mean
-    param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED]
+    param : [x, axis, keep_dim, false]

 - api : multiply
   args : (const Tensor& x, const Tensor& y)
@@ -130,13 +131,15 @@
     param : [x, y, -1]

 - api : sum
-  args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
+  args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
   output : Tensor
   infer_meta :
     func : ReduceInferMeta
+    param: [x, axis, keep_dim, dtype]
   kernel :
     func : sum
-    param : [x, axis, keep_dim, false, x.dtype(), DataType::UNDEFINED]
+    param : [x, axis, keep_dim, false, DataType::UNDEFINED]
+    data_type : x

 - api : zeros_like
   args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
......
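For reference, the updated YAML args imply C++ API signatures along these lines (reconstructed from the args strings and consistent with the updated sum test above; the actual generated headers may differ in namespaces and qualifiers):

Tensor mean(const Tensor& x, const std::vector<int64_t>& axis = {},
            bool keep_dim = false);
Tensor sum(const Tensor& x, const std::vector<int64_t>& axis = {},
           DataType dtype = DataType::UNDEFINED, bool keep_dim = false);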