未验证 提交 5eec8cf5 编写于 作者: W wangchaochaohu 提交者: GitHub

fix the mean grad OP performance improvement test=develop (#21658)

上级 29f64c8c
@@ -31,10 +31,11 @@ struct DivideFunctor {
 };
 template <typename T>
-__global__ void MeanRunKernel(const T in_data, T* out_data, int N) {
+__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  T data = in_data[0];
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    out_data[idx] = in_data / (static_cast<T>(N));
+    out_data[idx] = data / (static_cast<T>(N));
   }
 }
@@ -85,7 +86,7 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {
     auto IG = context.Output<Tensor>(framework::GradVarName("X"));
     IG->mutable_data<T>(context.GetPlace());
-    T in_data = OG[0];
+    auto in_data = OG->data<T>();
     auto size_prob = IG->numel();
     auto out_data = IG->data<T>();
     int threads = 512;
@@ -105,6 +106,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
     ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    mean_grad, ops::MeanGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+    mean_grad,
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
+                            plat::float16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册