Unverified commit 5756d3e5, authored by chentianyu03, committed by GitHub

Modify reduce_sum OP to use the templated complex types and rewrite its IdentityFunctor struct (#33164)

Parent 481ee79f
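The core of the first hunk below is that TensorReduceFunctor no longer stores a ready-made transform object. The transform is now passed as a template template parameter, so TransformOp<Tx, Ty> can be instantiated only once the output type Ty is known inside apply<Ty>(). The following is a minimal, self-contained sketch of that pattern in plain C++; ReduceFunctor and the std::vector-based reduction are hypothetical stand-ins, not the Paddle code itself.

#include <iostream>
#include <vector>

// Transform that also performs the Tx -> Ty cast (mirrors the new IdentityFunctor).
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  Ty operator()(const Tx& x) const { return static_cast<Ty>(x); }
};

// The transform is a template template parameter: the functor builds
// TransformOp<Tx, Ty> once Ty is chosen, instead of storing a fixed instance.
template <typename Tx, template <typename, typename> class TransformOp>
struct ReduceFunctor {
  const std::vector<Tx>& in;

  template <typename Ty>
  Ty apply() const {
    TransformOp<Tx, Ty> transform;  // constructed here, with the final Ty
    Ty acc = Ty(0);
    for (const Tx& v : in) acc += transform(v);
    return acc;
  }
};

int main() {
  std::vector<float> data = {1.5f, 2.5f, 3.0f};
  ReduceFunctor<float, IdentityFunctor> reduce{data};
  std::cout << reduce.apply<float>() << "\n";   // same input reduced as float
  std::cout << reduce.apply<double>() << "\n";  // ... and as double
  return 0;
}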
@@ -366,33 +366,32 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }
 
-template <typename Tx, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename ReduceOp,
+          template <typename, typename> class TransformOp>
 struct TensorReduceFunctor {
   const framework::Tensor& x;
   framework::Tensor* y;
   std::vector<int> origin_reduce_dims;
   const double& init;
   const ReduceOp& reducer;
-  const TransformOp& transformer;
   gpuStream_t stream;
   TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
                       std::vector<int> origin_reduce_dims, const double& init,
-                      const ReduceOp& reducer, const TransformOp& transformer,
-                      gpuStream_t stream)
+                      const ReduceOp& reducer, gpuStream_t stream)
       : x(x),
         y(y),
         origin_reduce_dims(origin_reduce_dims),
         init(init),
         reducer(reducer),
-        transformer(transformer),
         stream(stream) {}
 
   template <typename Ty>
   void apply() const {
     const Ty& init_cast = static_cast<Ty>(init);
-    TensorReduce<Tx, Ty, ReduceOp, TransformOp>(
-        x, y, origin_reduce_dims, init_cast, reducer, transformer, stream);
+    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
+        x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
+        stream);
   }
 };
@@ -119,9 +119,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
                       ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                      paddle::platform::complex64, ops::SumFunctor>,
+                      paddle::platform::complex<float>, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                      paddle::platform::complex128,
+                      paddle::platform::complex<double>,
                       ops::SumFunctor>);
@@ -130,10 +130,9 @@ using CPUReduceSumGradKernel =
     ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
                              ops::SumGradFunctor, true>;
 
-REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<bool>,
-                       CPUReduceSumGradKernel<float>,
-                       CPUReduceSumGradKernel<double>,
-                       CPUReduceSumGradKernel<int>,
-                       CPUReduceSumGradKernel<int64_t>,
-                       CPUReduceSumGradKernel<paddle::platform::complex64>,
-                       CPUReduceSumGradKernel<paddle::platform::complex128>);
+REGISTER_OP_CPU_KERNEL(
+    reduce_sum_grad, CPUReduceSumGradKernel<bool>,
+    CPUReduceSumGradKernel<float>, CPUReduceSumGradKernel<double>,
+    CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int64_t>,
+    CPUReduceSumGradKernel<paddle::platform::complex<float>>,
+    CPUReduceSumGradKernel<paddle::platform::complex<double>>);
@@ -18,11 +18,13 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x);
+  }
 };
 
 template <typename T>
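The rewritten IdentityFunctor above is no longer a pure pass-through: with separate Tx and Ty parameters it casts each element to the output type, which is what lets the reduction accumulate a narrower input into a wider out_dtype. A hedged stand-in in plain C++, using std::complex in place of paddle::platform::complex and dropping HOSTDEVICE; it only illustrates the cast, not the kernel.

#include <complex>
#include <iostream>

// Same shape as the struct in the hunk above, minus the HOSTDEVICE qualifiers.
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  Ty operator()(const Tx& x) const { return static_cast<Ty>(x); }
};

int main() {
  // Promoting use, as in the out_dtype path: complex<float> -> complex<double>.
  IdentityFunctor<std::complex<float>, std::complex<double>> promote;
  // Default Ty = Tx, as in the else branch that instantiates IdentityFunctor<T, T>.
  IdentityFunctor<float> passthrough;

  std::complex<double> y = promote(std::complex<float>(1.0f, 0.5f));
  std::cout << y.real() << " " << y.imag() << "\n";  // prints: 1 0.5
  std::cout << passthrough(2.5f) << "\n";            // prints: 2.5
  return 0;
}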
@@ -56,13 +58,13 @@ class ReduceSumKernel : public framework::OpKernel<T> {
     if (out_dtype >= 0) {
       framework::VisitDataTypeSmall(
           static_cast<framework::proto::VarType::Type>(out_dtype),
-          TensorReduceFunctor<T, cub::Sum, IdentityFunctor<T>>(
+          TensorReduceFunctor<T, cub::Sum, IdentityFunctor>(
               *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
-              IdentityFunctor<T>(), stream));
+              stream));
     } else {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T, T>>(
           *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor<T>(), stream);
+          IdentityFunctor<T, T>(), stream);
     }
   }
 };
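In the out_dtype branch above, the output type is only known at run time, so the kernel hands a TensorReduceFunctor to framework::VisitDataTypeSmall and lets the visitor choose Ty and call apply<Ty>(). A minimal sketch of that dispatch shape follows; DataType, VisitDataType, and PrintSizeFunctor are hypothetical stand-ins, not the real Paddle visitor.

#include <iostream>
#include <stdexcept>

enum class DataType { FP32, FP64 };

// Map a runtime dtype value to a compile-time type and forward to apply<Ty>().
template <typename Functor>
void VisitDataType(DataType dtype, const Functor& functor) {
  switch (dtype) {
    case DataType::FP32:
      functor.template apply<float>();
      break;
    case DataType::FP64:
      functor.template apply<double>();
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

struct PrintSizeFunctor {
  template <typename Ty>
  void apply() const {
    std::cout << "output element size: " << sizeof(Ty) << " bytes\n";
  }
};

int main() {
  VisitDataType(DataType::FP64, PrintSizeFunctor{});  // prints: output element size: 8 bytes
  return 0;
}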
@@ -70,9 +72,9 @@ class ReduceSumKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<bool>,
-                        ops::ReduceSumKernel<float>,
+REGISTER_OP_CUDA_KERNEL(
+    reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
     ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
     ops::ReduceSumKernel<int64_t>,
-    ops::ReduceSumKernel<paddle::platform::complex64>,
-    ops::ReduceSumKernel<paddle::platform::complex128>);
+    ops::ReduceSumKernel<paddle::platform::complex<float>>,
+    ops::ReduceSumKernel<paddle::platform::complex<double>>);