Unverified · Commit 524389ee, authored by niuliling123 and committed by GitHub

Add the transformop parameter in TensorReduceFunctorImpl (#38135)

* Add the transformop parameter in TensorReduceFunctorImpl
Parent be874c08
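The change replaces the fused reducer-plus-transformer functors (CustomSum, SquareSum, ExpAndSum, ...) with two independent pieces: a reducer from kernel_primitives (kps::AddFunctor, kps::MaxFunctor, ...) and a TransformOp that is passed both as a template argument and as a constructed instance. A minimal before/after sketch assembled from the hunks below (illustrative only, not an extra file in this commit):

// Before: reducer and element-wise transform were fused in a hand-written
// functor such as CustomSum.
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);

// After: the reducer is a kps functor and the transform is supplied
// explicitly, as a template argument and as a runtime instance.
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
    *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);

// Transforms that need the reduction size take it in their constructor,
// e.g. the adaptive average pooling path in this commit:
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::DivideFunctor<T>>(
    *in_x, out, kps::DivideFunctor<T>(reduce_num), reduce_dim, stream);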
@@ -18,29 +18,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-
-template <typename Tx, typename Ty = Tx>
-struct SquareTransformer {
-  HOSTDEVICE explicit inline SquareTransformer(int n) {}
-
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x) * static_cast<Ty>(x);
-  }
-
-  HOSTDEVICE inline Ty operator()(const Tx* x) const {
-    return static_cast<Ty>(x[0]) * static_cast<Ty>(x[0]);
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct SquareSum {
-  using Transformer = SquareTransformer<Tx, Ty>;
-
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
 template <>
 class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
@@ -97,8 +74,10 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
     }
     Tensor tmp = context.AllocateTmpTensor<float, platform::CUDADeviceContext>(
        {1}, dev_ctx);
-    TensorReduceFunctorImpl<platform::float16, float, SquareSum>(
-        *input, &tmp, reduce_dims, dev_ctx.stream());
+    TensorReduceFunctorImpl<platform::float16, float, kps::AddFunctor,
+                            kps::SquareFunctor<platform::float16, float>>(
+        *input, &tmp, kps::SquareFunctor<platform::float16, float>(),
+        reduce_dims, dev_ctx.stream());
     auto tmp_eigen = EigenVector<float>::Flatten(tmp);
     auto x_norm = tmp_eigen.sqrt();
......
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -91,7 +90,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
     }
     std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
   }
 }
 // dy
@@ -106,7 +106,8 @@ default_elementwise_add_grad(const framework::ExecutionContext& ctx,
   } else {
     std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dy, reduce_dims, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        *dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
   }
 }
 }
......
@@ -14,7 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -69,7 +68,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
     }
     std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
   }
 }
 // dy
@@ -90,7 +90,8 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
   } else {
     std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
+        *dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
   }
 }
 }
......
@@ -16,11 +16,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 namespace paddle {
 namespace operators {
 // support gemm-nt and gemm-nn, which is used in fused_attention_op.
 template <typename T>
 class AttnMatMul {
@@ -165,8 +164,8 @@ class AttnMatMul {
                            (input_dims[2] == output_dims[0]));
     if (support_case_1 || support_case_2) {
       gpuStream_t stream = dev_ctx_.stream();
-      TensorReduceFunctorImpl<T, T, CustomSum>(*d_output, d_bias, {0, 1},
-                                               stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          *d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1}, stream);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Only support reduce when the input dims are [0,1,2,3,4] and "
......
@@ -24,7 +24,6 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/margin_cross_entropy_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/string/string_helper.h"
@@ -128,17 +127,6 @@ __global__ void AddMarginToPositiveLogitsKernel(
   }
 }
-template <typename Tx, typename Ty = Tx>
-struct ExpAndSum {
-  using Transformer = kps::ExpFunctor<Tx>;
-
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
 template <typename T>
 __global__ void ScaleLogitKernel(T* logits, const float scale, const int64_t N,
                                  const int64_t D) {
@@ -309,8 +297,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
     logits_max =
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
     T* logits_max_buff = logits_max.mutable_data<T>(place);
-    TensorReduceFunctorImpl<T, T, CustomMax>(softmax_2d, &logits_max, {1},
-                                             dev_ctx.stream());
+    TensorReduceFunctorImpl<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
+        softmax_2d, &logits_max, kps::IdentityFunctor<T>(), {1},
+        dev_ctx.stream());
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nranks > 1) {
@@ -330,8 +319,9 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
     sum_exp_logits =
         ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, 1}, dev_ctx);
     T* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
-    TensorReduceFunctorImpl<T, T, ExpAndSum>(softmax_2d, &sum_exp_logits, {1},
-                                             dev_ctx.stream());
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
+        softmax_2d, &sum_exp_logits, kps::ExpFunctor<T>(), {1},
+        dev_ctx.stream());
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nranks > 1) {
......
@@ -59,28 +59,17 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
   return pow(base, exponent);
 }
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
-  template <typename T>
-  HOSTDEVICE inline T operator()(const T& x) const {
-    return static_cast<T>(x);
-  }
-};
-
+template <typename T>
 struct NonzeroFunctor {
   HOSTDEVICE explicit inline NonzeroFunctor() {}
-  HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
-  template <typename T>
   HOSTDEVICE inline T operator()(const T& x) const {
     return static_cast<T>(static_cast<double>(x) != 0);
   }
 };
+template <typename T>
 struct AbsFunctor {
   HOSTDEVICE explicit inline AbsFunctor() {}
-  HOSTDEVICE explicit inline AbsFunctor(int n) {}
-  template <typename T>
   HOSTDEVICE inline T operator()(const T& x) const {
     return static_cast<T>(inline_abs(x));
   }
@@ -106,48 +95,6 @@ struct PowFunctor {
   float porder;
 };
-template <typename Tx, typename Ty = Tx>
-struct AbsAndMin {
-  using Transformer = AbsFunctor;
-  using MT = typename details::MPTypeTrait<Ty>::Type;
-  inline Ty initial() {
-    return static_cast<Ty>(std::numeric_limits<MT>::infinity());
-  }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return (a < b) ? a : b;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct AbsAndMax {
-  using Transformer = AbsFunctor;
-  using MT = typename details::MPTypeTrait<Ty>::Type;
-  inline Ty initial() {
-    return static_cast<Ty>(-std::numeric_limits<MT>::infinity());
-  }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return (a > b) ? a : b;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct NonzeroAndSum {
-  using Transformer = NonzeroFunctor;
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct IdentityAndSum {
-  using Transformer = IdentityFunctor;
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
-    return b + a;
-  }
-};
-
 template <typename DeviceContext, typename T>
 class PnormCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -167,14 +114,14 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
     using MT = typename details::MPTypeTrait<T>::Type;
     if (porder == 0) {
-      TensorReduceFunctorImpl<T, T, NonzeroAndSum>(*in_x, out_norm, reduce_axis,
-                                                   stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          *in_x, out_norm, NonzeroFunctor<T>(), reduce_axis, stream);
     } else if (porder == INFINITY) {
-      TensorReduceFunctorImpl<T, T, AbsAndMax>(*in_x, out_norm, reduce_axis,
-                                               stream);
+      TensorReduceFunctorImpl<T, T, kps::MaxFunctor, AbsFunctor<T>>(
+          *in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
     } else if (porder == -INFINITY) {
-      TensorReduceFunctorImpl<T, T, AbsAndMin>(*in_x, out_norm, reduce_axis,
-                                               stream);
+      TensorReduceFunctorImpl<T, T, kps::MinFunctor, AbsFunctor<T>>(
+          *in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
     } else {
       framework::Tensor tmp_x;
       tmp_x.mutable_data<T>(xdim, ctx.GetPlace());
@@ -189,8 +136,8 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
           cuda_ctx, ins, &outs, func);
       framework::Tensor tmp_y;
      tmp_y.mutable_data<T>(ndim, ctx.GetPlace());
-      TensorReduceFunctorImpl<T, T, IdentityAndSum>(tmp_x, &tmp_y, reduce_axis,
-                                                    stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          tmp_x, &tmp_y, kps::IdentityFunctor<T>(), reduce_axis, stream);
       const framework::Tensor* tmp_norm = &tmp_y;
       ins = {tmp_norm};
       outs = {out_norm};
......
@@ -23,7 +23,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
@@ -203,13 +202,14 @@ class PoolKernel : public framework::OpKernel<T> {
     } else if (pooling_type == "avg") {
       std::vector<int> reduce_dim;
       int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);
       if (reduce_num > 0 &&
           adaptive) {  // for adaptive_avg_pool2d && output_size == 1
 #if defined(__HIPCC__) || defined(__NVCC__)
         auto stream = dev_ctx.stream();
-        TensorReduceFunctorImpl<T, T, CustomMean>(*in_x, out, reduce_dim,
-                                                  stream);
+        TensorReduceFunctorImpl<T, T, kps::AddFunctor,
+                                kps::DivideFunctor<T>>(
+            *in_x, out, kps::DivideFunctor<T>(reduce_num), reduce_dim,
+            stream);
 #else  // for cpu
         paddle::operators::math::Pool2dFunctor<
             DeviceContext, paddle::operators::math::AvgPool<T>, T>
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/prelu.h"
 #include "paddle/fluid/operators/prelu_op.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 namespace paddle {
@@ -123,13 +123,6 @@ class PreluOpGradFunctor {
   }
 };
-struct IdentityFunctor {
-  template <typename T>
-  HOSTDEVICE inline T operator()(const T& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 class CUDAPReluGradKernel : public framework::OpKernel<T> {
  public:
@@ -192,9 +185,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
       reduce_dims.push_back(i);
     }
-    TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-        dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
-        IdentityFunctor(), stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dalpha_tmp, dalpha, kps::IdentityFunctor<T>(), reduce_dims, stream);
   }
 };
......
@@ -13,8 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 REGISTER_OP_CUDA_KERNEL(
     reduce_all,
-    ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalAnd>);
+    ops::ReduceCudaKernel<bool, kps::LogicalAndFunctor, kps::IdentityFunctor>);
@@ -13,9 +13,8 @@
 // limitations under the License.
 #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 REGISTER_OP_CUDA_KERNEL(
     reduce_any,
-    ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalOr>);
+    ops::ReduceCudaKernel<bool, kps::LogicalOrFunctor, kps::IdentityFunctor>);
@@ -11,13 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 // reduce_max
 REGISTER_OP_CUDA_KERNEL(
-    reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
-    ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
-    ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
-    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
+    reduce_max,
+    ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
@@ -13,11 +13,11 @@
 // limitations under the License.
 #include <vector>
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 REGISTER_OP_CUDA_KERNEL(
-    reduce_mean, ops::ReduceCudaKernel<bool, paddle::operators::CustomMean>,
-    ops::ReduceCudaKernel<float, paddle::operators::CustomMean>,
-    ops::ReduceCudaKernel<double, paddle::operators::CustomMean>);
+    reduce_mean,
+    ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::DivideFunctor>,
+    ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>,
+    ops::ReduceCudaKernel<double, kps::AddFunctor, kps::DivideFunctor>);
@@ -11,13 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 // reduce_min
 REGISTER_OP_CUDA_KERNEL(
-    reduce_min, ops::ReduceCudaKernel<float, paddle::operators::CustomMin>,
-    ops::ReduceCudaKernel<double, paddle::operators::CustomMin>,
-    ops::ReduceCudaKernel<int, paddle::operators::CustomMin>,
-    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMin>);
+    reduce_min,
+    ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
@@ -44,11 +44,11 @@ namespace cub = hipcub;
 #define REDUCE_SPLIT_BOUNDARY 512
 #define REDUCE_VEC_SIZE 4
-namespace kps = paddle::operators::kernel_primitives;
 namespace paddle {
 namespace operators {
+namespace kps = paddle::operators::kernel_primitives;
 namespace details {
 static inline int GetLastPow2(int n) {
@@ -722,12 +722,12 @@ __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
   }
 }
-template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer, MPType init,
+                               const ReduceOp& reducer,
+                               const TransformOp& transform, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
-  using TransformOp = typename ReduceOp::Transformer;
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
     int stride_left = config.reduce_num;
@@ -743,15 +743,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
 #ifdef PADDLE_WITH_XPU2
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                     OneDimIndexCal><<<8, 128, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator, dim);
+        x_data, config.output_data, reducer, transform, init, config.reduce_num,
+        config.left_num, config.reduce_last_dim, reduce_index_calculator,
+        left_index_calculator, dim);
 #else
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                     OneDimIndexCal><<<config.grid, config.block, 0, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator, dim);
+        x_data, config.output_data, reducer, transform, init, config.reduce_num,
+        config.left_num, config.reduce_last_dim, reduce_index_calculator,
+        left_index_calculator, dim);
 #endif
   } else {
@@ -771,15 +771,15 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
 #ifdef PADDLE_WITH_XPU2
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                     IndexCalculator><<<8, 128, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator, dim);
+        x_data, config.output_data, reducer, transform, init, config.reduce_num,
+        config.left_num, config.reduce_last_dim, reduce_index_calculator,
+        left_index_calculator, dim);
 #else
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
                     IndexCalculator><<<config.grid, config.block, 0, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        init, config.reduce_num, config.left_num, config.reduce_last_dim,
-        reduce_index_calculator, left_index_calculator, dim);
+        x_data, config.output_data, reducer, transform, init, config.reduce_num,
+        config.left_num, config.reduce_last_dim, reduce_index_calculator,
+        left_index_calculator, dim);
 #endif
   }
@@ -802,23 +802,22 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
 #ifdef PADDLE_WITH_XPU2
     ReduceHigherDimKernel<Ty, Ty, MPType, ReduceOp,
                           kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
-        config.output_data, y_data, reducer,
-        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y, dim);
+        config.output_data, y_data, reducer, kps::IdentityFunctor<Ty, MPType>(),
+        init, config.grid.y, config.left_num, config.grid.y, dim);
 #else
     ReduceHigherDimKernel<
         Ty, Ty, MPType, ReduceOp,
         kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
-        config.output_data, y_data, reducer,
-        kps::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y, dim);
+        config.output_data, y_data, reducer, kps::IdentityFunctor<Ty, MPType>(),
+        init, config.grid.y, config.left_num, config.grid.y, dim);
 #endif
   }
 }
-template <typename Tx, typename Ty,
-          template <typename, typename> class ReduceOp>
+template <typename Tx, typename Ty, template <typename> class ReduceOp,
+          typename TransformOp>
 void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
+                             const TransformOp& transform,
                              std::vector<int> origin_reduce_dims,
                              gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
@@ -853,10 +852,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
       (!std::is_same<Tx, paddle::platform::float16>::value);
   if (use_cub_reduce) {
     // launch CUB::Reduce
-    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
-    auto reducer = ReduceOp<Tx, Ty>();
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, TransformOp(config.reduce_num));
+    auto reducer = ReduceOp<Ty>();
+    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                    transform);
     size_t temp_storage_bytes = 0;
     cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                               config.reduce_num, reducer, reducer.initial(),
@@ -873,7 +871,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
   using MPType = typename details::MPTypeTrait<Ty>::Type;
-  auto reducer = ReduceOp<Tx, MPType>();
+  auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
@@ -882,7 +880,6 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   // 32
   // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
   if (config.reduce_type == ReduceType::kReduceHigherDim) {
-    using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
     kps::DimConfig dim =
         kps::DimConfig(config.grid.x, config.grid.y, config.grid.z,
                        config.block.x, config.blocking_size, 0);
@@ -890,18 +887,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                        config.reduce_num % config.blocking_size, 0);
 #ifdef PADDLE_WITH_XPU2
-    ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>,
+    ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>,
                           TransformOp><<<8, 128, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        reducer.initial(), config.reduce_num, config.left_num,
-        config.blocking_size, dim);
+        x_data, config.output_data, reducer, transform, reducer.initial(),
+        config.reduce_num, config.left_num, config.blocking_size, dim);
 #else
     ReduceHigherDimKernel<
-        Tx, Ty, MPType, ReduceOp<Tx, MPType>,
+        Tx, Ty, MPType, ReduceOp<MPType>,
        TransformOp><<<config.grid, config.block, 0, stream>>>(
-        x_data, config.output_data, reducer, TransformOp(config.reduce_num),
-        reducer.initial(), config.reduce_num, config.left_num,
-        config.blocking_size, dim);
+        x_data, config.output_data, reducer, transform, reducer.initial(),
+        config.reduce_num, config.left_num, config.blocking_size, dim);
 #endif
     if (config.should_reduce_again) {
@@ -913,14 +908,14 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
 #ifdef PADDLE_WITH_XPU2
       ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
+          Ty, Ty, MPType, ReduceOp<MPType>,
          kps::IdentityFunctor<Ty, MPType>><<<8, 128, stream>>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
          config.grid.y, config.left_num, config.grid.y, dim2);
 #else
       ReduceHigherDimKernel<
-          Ty, Ty, MPType, ReduceOp<Tx, MPType>,
+          Ty, Ty, MPType, ReduceOp<MPType>,
          kps::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
          config.output_data, y_data, reducer,
          kps::IdentityFunctor<Ty, MPType>(config.grid.y), reducer.initial(),
@@ -933,23 +928,32 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
   // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
   // function will be used
-  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
-      x_data, y_data, reducer, reducer.initial(), stream, config);
+  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>(
+      x_data, y_data, reducer, transform, reducer.initial(), stream, config);
 }
-template <typename Tx, template <typename, typename> class ReduceOp>
+template <typename Tx, template <typename> class ReduceOp,
+          template <typename, typename> class TransformOp>
 struct TensorReduceFunc {
   const framework::Tensor& x;
   framework::Tensor* y;
   std::vector<int> origin_reduce_dims;
   gpuStream_t stream;
+  int reduce_num;
   TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
-                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
-      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
+                   std::vector<int> origin_reduce_dims, int num_reduce,
+                   gpuStream_t stream)
+      : x(x),
+        y(y),
+        origin_reduce_dims(origin_reduce_dims),
+        reduce_num(num_reduce),
+        stream(stream) {}
   template <typename Ty>
   void apply() const {
-    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
+    using MPType = typename details::MPTypeTrait<Ty>::Type;
+    TensorReduceFunctorImpl<Tx, Ty, ReduceOp, TransformOp<Tx, MPType>>(
+        x, y, TransformOp<Tx, MPType>(reduce_num), origin_reduce_dims, stream);
   }
 };
......
@@ -670,7 +670,8 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
 };
 #if defined(__HIPCC__) || defined(__NVCC__)
-template <typename T, template <typename, typename> class ReduceOp>
+template <typename T, template <typename> class ReduceOp,
+          template <typename, typename> class TransformOp>
 class ReduceCudaKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -682,15 +683,19 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
     std::vector<int> reduce_dims =
         GetReduceDim(dims, input->dims().size(), reduce_all);
+    int reduce_num = 1;
+    for (int i = 0; i < input->dims().size(); i++) {
+      reduce_num *= (input->dims())[i];
+    }
     gpuStream_t stream = context.cuda_device_context().stream();
     if (out_dtype >= 0) {
       framework::VisitDataTypeSmall(
           static_cast<framework::proto::VarType::Type>(out_dtype),
-          TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
+          TensorReduceFunc<T, ReduceOp, TransformOp>(
+              *input, output, reduce_dims, reduce_num, stream));
     } else {
-      TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
-                                              stream);
+      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, T>>(
+          *input, output, TransformOp<T, T>(reduce_num), reduce_dims, stream);
     }
   }
 };
......
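For the generic ReduceCudaKernel, the reducer and the transform are now chosen at registration time, and the kernel forwards reduce_num (the product of the input dimensions computed in Compute() above) into the transform's constructor. A minimal sketch of how these pieces compose, assembled from the hunks in this commit (illustrative only, not an extra file in the tree):

// Registration names the (reducer, transform) pair; mean = Add + Divide.
REGISTER_OP_CUDA_KERNEL(
    reduce_mean,
    ops::ReduceCudaKernel<float, kps::AddFunctor, kps::DivideFunctor>);

// Inside ReduceCudaKernel<T, ReduceOp, TransformOp>::Compute, the else branch
// then expands (for the registration above) to roughly:
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::DivideFunctor<T, T>>(
    *input, output, kps::DivideFunctor<T, T>(reduce_num), reduce_dims, stream);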
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
 REGISTER_OP_CUDA_KERNEL(
-    reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
-    ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
-    ops::ReduceCudaKernel<double, paddle::operators::CustomMul>,
-    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
+    reduce_prod,
+    ops::ReduceCudaKernel<float, kps::MulFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int, kps::MulFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<double, kps::MulFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int64_t, kps::MulFunctor, kps::IdentityFunctor>);
@@ -11,18 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 REGISTER_OP_CUDA_KERNEL(
-    reduce_sum, ops::ReduceCudaKernel<bool, paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<float, paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<double, paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<paddle::platform::float16,
-                          paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<int, paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<paddle::platform::complex<float>,
-                          paddle::operators::CustomSum>,
-    ops::ReduceCudaKernel<paddle::platform::complex<double>,
-                          paddle::operators::CustomSum>);
+    reduce_sum,
+    ops::ReduceCudaKernel<bool, kps::AddFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<float, kps::AddFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<double, kps::AddFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
+                          kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int, kps::AddFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<int64_t, kps::AddFunctor, kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<paddle::platform::complex<float>, kps::AddFunctor,
+                          kps::IdentityFunctor>,
+    ops::ReduceCudaKernel<paddle::platform::complex<double>, kps::AddFunctor,
+                          kps::IdentityFunctor>);
@@ -15,21 +15,12 @@
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/trace_op.h"
 namespace paddle {
 namespace operators {
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 class TraceCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -48,9 +39,8 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
       auto stream = context.cuda_device_context().stream();
       std::vector<int> reduce_dims;
       reduce_dims.push_back(out->dims().size());
-      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-          diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor(), stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          diag, out, kps::IdentityFunctor<T>(), reduce_dims, stream);
     } else {
       math::SetConstant<DeviceContext, T> functor;
       functor(context.device_context<DeviceContext>(), out, static_cast<T>(0));
......
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/triangular_solve_op.h"
@@ -44,7 +43,8 @@ struct MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
       }
     }
     gpuStream_t stream = ctx.cuda_device_context().stream();
-    TensorReduceFunctorImpl<T, T, CustomSum>(in, out, out_reduce_dims, stream);
+    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        in, out, kps::IdentityFunctor<T>(), out_reduce_dims, stream);
   }
 };
......