Unverified commit 5756d3e5, authored by chentianyu03, committed via GitHub

modify to complex template types in reduce_sum OP and rewrite it's IdentityFunctor struct (#33164)

Parent commit: 481ee79f
...@@ -366,33 +366,32 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y, ...@@ -366,33 +366,32 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
#undef CUB_BLOCK_DIM_CASE #undef CUB_BLOCK_DIM_CASE
} }
template <typename Tx, typename ReduceOp, typename TransformOp> template <typename Tx, typename ReduceOp,
template <typename, typename> class TransformOp>
struct TensorReduceFunctor { struct TensorReduceFunctor {
const framework::Tensor& x; const framework::Tensor& x;
framework::Tensor* y; framework::Tensor* y;
std::vector<int> origin_reduce_dims; std::vector<int> origin_reduce_dims;
const double& init; const double& init;
const ReduceOp& reducer; const ReduceOp& reducer;
const TransformOp& transformer;
gpuStream_t stream; gpuStream_t stream;
TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y, TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
std::vector<int> origin_reduce_dims, const double& init, std::vector<int> origin_reduce_dims, const double& init,
const ReduceOp& reducer, const TransformOp& transformer, const ReduceOp& reducer, gpuStream_t stream)
gpuStream_t stream)
: x(x), : x(x),
y(y), y(y),
origin_reduce_dims(origin_reduce_dims), origin_reduce_dims(origin_reduce_dims),
init(init), init(init),
reducer(reducer), reducer(reducer),
transformer(transformer),
stream(stream) {} stream(stream) {}
template <typename Ty> template <typename Ty>
void apply() const { void apply() const {
const Ty& init_cast = static_cast<Ty>(init); const Ty& init_cast = static_cast<Ty>(init);
TensorReduce<Tx, Ty, ReduceOp, TransformOp>( TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
x, y, origin_reduce_dims, init_cast, reducer, transformer, stream); x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
stream);
} }
}; };
......
...@@ -119,9 +119,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -119,9 +119,9 @@ REGISTER_OP_CPU_KERNEL(
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t, ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64, ops::SumFunctor>, paddle::platform::complex<float>, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128, paddle::platform::complex<double>,
ops::SumFunctor>); ops::SumFunctor>);
...@@ -130,10 +130,9 @@ using CPUReduceSumGradKernel = ...@@ -130,10 +130,9 @@ using CPUReduceSumGradKernel =
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T, ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
ops::SumGradFunctor, true>; ops::SumGradFunctor, true>;
// Register the CPU gradient kernels for reduce_sum, covering the same
// element types as the forward op (including complex<float>/complex<double>).
REGISTER_OP_CPU_KERNEL(
    reduce_sum_grad, CPUReduceSumGradKernel<bool>,
    CPUReduceSumGradKernel<float>, CPUReduceSumGradKernel<double>,
    CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int64_t>,
    CPUReduceSumGradKernel<paddle::platform::complex<float>>,
    CPUReduceSumGradKernel<paddle::platform::complex<double>>);
...@@ -18,11 +18,13 @@ ...@@ -18,11 +18,13 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename Tx, typename Ty = Tx>
struct IdentityFunctor { struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {} HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; } HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x);
}
}; };
template <typename T> template <typename T>
...@@ -56,13 +58,13 @@ class ReduceSumKernel : public framework::OpKernel<T> { ...@@ -56,13 +58,13 @@ class ReduceSumKernel : public framework::OpKernel<T> {
if (out_dtype >= 0) { if (out_dtype >= 0) {
framework::VisitDataTypeSmall( framework::VisitDataTypeSmall(
static_cast<framework::proto::VarType::Type>(out_dtype), static_cast<framework::proto::VarType::Type>(out_dtype),
TensorReduceFunctor<T, cub::Sum, IdentityFunctor<T>>( TensorReduceFunctor<T, cub::Sum, IdentityFunctor>(
*input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(), *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
IdentityFunctor<T>(), stream)); stream));
} else { } else {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>( TensorReduce<T, T, cub::Sum, IdentityFunctor<T, T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(), *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream); IdentityFunctor<T, T>(), stream);
} }
} }
}; };
...@@ -70,9 +72,9 @@ class ReduceSumKernel : public framework::OpKernel<T> { ...@@ -70,9 +72,9 @@ class ReduceSumKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Register the CUDA forward kernels for reduce_sum, matching the CPU type
// list (bool/float/double/int/int64_t plus complex<float>/complex<double>).
REGISTER_OP_CUDA_KERNEL(
    reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
    ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
    ops::ReduceSumKernel<int64_t>,
    ops::ReduceSumKernel<paddle::platform::complex<float>>,
    ops::ReduceSumKernel<paddle::platform::complex<double>>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.