diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu
index 921c2e1298906655767c1e7f30dc34b2c564c671..f0f895c08aac81252903ffd9a5ed1c931701adbb 100644
--- a/paddle/fluid/operators/mean_op.cu
+++ b/paddle/fluid/operators/mean_op.cu
@@ -11,15 +11,99 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "cub/cub.cuh"
 #include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct DivideFunctor {
+  HOSTDEVICE explicit inline DivideFunctor(int n)
+      : n_inv(static_cast<T>(1.0 / n)) {}
+
+  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
+
+ private:
+  T n_inv;
+};
+
+template <typename T>
+__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    out_data[idx] = in_data[0] / (static_cast<T>(N));
+  }
+}
+
+template <typename DeviceContext, typename T>
+class MeanCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Output<Tensor>("Out");
+
+    output->mutable_data<T>(context.GetPlace());
+    auto size_prob = input->numel();
+    const T* in_data = input->data<T>();
+    T* out_data = output->mutable_data<T>(context.GetPlace());
+    auto stream = context.cuda_device_context().stream();
+
+    DivideFunctor<T> transformer(size_prob);
+    cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
+        in_data, transformer);
+    size_t temp_storage_bytes = 0;
+
+    auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
+                                      out_data, size_prob, stream);
+    PADDLE_ENFORCE_CUDA_SUCCESS(err,
+                                "MeanOP failed to get reduce workspace size",
+                                cudaGetErrorString(err));
+    framework::Tensor tmp;
+    auto* temp_storage = tmp.mutable_data<uint8_t>(
+        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
+        context.GetPlace());
+    err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
+                                 out_data, size_prob, stream);
+    PADDLE_ENFORCE_CUDA_SUCCESS(err, "MeanOP failed to run reduce computation",
+                                cudaGetErrorString(err));
+  }
+};
+
+template <typename DeviceContext, typename T>
+class MeanCUDAGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
+    PADDLE_ENFORCE_EQ(
+        OG->numel(), 1,
+        platform::errors::InvalidArgument(
+            "Mean Gradient Input Tensor len should be 1. But received %d",
+            OG->numel()));
+    auto IG = context.Output<Tensor>(framework::GradVarName("X"));
+    IG->mutable_data<T>(context.GetPlace());
+
+    const T* in_data = OG->data<T>();
+    auto size_prob = IG->numel();
+    auto out_data = IG->data<T>();
+    int threads = 512;
+    int grid = (size_prob + threads - 1) / threads;
+    auto stream = context.cuda_device_context().stream();
+    MeanRunKernel<T><<<grid, threads, 0, stream>>>(in_data, out_data,
+                                                   size_prob);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+    mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     mean_grad, ops::MeanGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>,
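
For reference, below is a minimal standalone sketch of the two-pass CUB pattern the new MeanCUDAKernel relies on, written outside Paddle (the Divide functor, the float element type, and the problem size are illustrative assumptions, not part of the diff): wrap the input in a cub::TransformInputIterator that scales each element by 1/N, call cub::DeviceReduce::Sum once with a null workspace pointer to query the temporary-storage size, allocate that buffer, then call Sum again to compute the mean in a single device-side reduction.

#include <cub/cub.cuh>

#include <cstdio>
#include <vector>

// Illustrative functor: scales each element by 1/N so that a plain sum of
// the transformed values yields the mean.
struct Divide {
  float n_inv;
  explicit Divide(int n) : n_inv(1.0f / n) {}
  __host__ __device__ float operator()(const float& x) const {
    return x * n_inv;
  }
};

int main() {
  const int n = 1 << 20;
  std::vector<float> host(n, 2.0f);  // every element is 2.0, so mean == 2.0

  float* d_in = nullptr;
  float* d_out = nullptr;
  cudaMalloc((void**)&d_in, n * sizeof(float));
  cudaMalloc((void**)&d_out, sizeof(float));
  cudaMemcpy(d_in, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // The reduction sees x / n instead of x, so Sum() produces the mean.
  cub::TransformInputIterator<float, Divide, const float*> trans_in(
      d_in, Divide(n));

  // Pass 1: a null workspace pointer makes CUB only report the bytes it needs.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, trans_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);

  // Pass 2: the same call with a real workspace runs the reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, trans_in, d_out, n);

  float mean = 0.0f;
  cudaMemcpy(&mean, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("mean = %f\n", mean);  // expect 2.000000

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}

The backward pass in the diff uses the complementary broadcast pattern instead of a reduction: MeanRunKernel is a grid-stride loop that writes out_grad[0] / N into every element of the input gradient.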