From f86073c4414c31a5ac764eaa4d6477e5c2046a9d Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Thu, 17 Feb 2022 21:34:09 +0800
Subject: [PATCH] [pten] move bernoulli kernel to pten (#39590)

* move bernoulli kernel to pten

* follow comments
---
 paddle/fluid/operators/bernoulli_op.cc      | 28 -------
 paddle/fluid/operators/bernoulli_op.cu      | 84 ---------------------
 paddle/fluid/platform/transform.h           | 47 ++++++++++++
 paddle/pten/kernels/bernoulli_kernel.h      | 27 +++++++
 paddle/pten/kernels/cpu/bernoulli_kernel.cc | 55 ++++++++++++++
 paddle/pten/kernels/gpu/bernoulli_kernel.cu | 77 +++++++++++++++++++
 6 files changed, 206 insertions(+), 112 deletions(-)
 delete mode 100644 paddle/fluid/operators/bernoulli_op.cu
 create mode 100644 paddle/pten/kernels/bernoulli_kernel.h
 create mode 100644 paddle/pten/kernels/cpu/bernoulli_kernel.cc
 create mode 100644 paddle/pten/kernels/gpu/bernoulli_kernel.cu

diff --git a/paddle/fluid/operators/bernoulli_op.cc b/paddle/fluid/operators/bernoulli_op.cc
index 79c4e2c2bba..ffb0173c463 100644
--- a/paddle/fluid/operators/bernoulli_op.cc
+++ b/paddle/fluid/operators/bernoulli_op.cc
@@ -49,30 +49,6 @@ class BernoulliOp : public framework::OperatorWithKernel {
   }
 };
 
-// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
-// Use std::random and thrust::random(thrust is a std library in CUDA) to
-// implement uniform random.
-template <typename T>
-class BernoulliOpKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto x = ctx.Input<framework::Tensor>("X");
-    auto out = ctx.Output<framework::Tensor>("Out");
-    auto *in_data = x->data<T>();
-    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
-
-    int64_t size = x->numel();
-    std::uniform_real_distribution<T> dist(0.0, 1.0);
-    auto gen_ptr = framework::DefaultCPUGenerator();
-    auto engine = gen_ptr->GetCPUEngine();
-
-    for (int64_t i = 0; i < size; ++i) {
-      out_data[i] = BernoulliFunctor(in_data[i], dist(*engine));
-    }
-  }
-};  // namespace operators
-
 }  // namespace operators
 }  // namespace paddle
 
@@ -82,7 +58,3 @@ REGISTER_OPERATOR(
     bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(bernoulli,
-                       ops::BernoulliOpKernel<plat::CPUDeviceContext, float>,
-                       ops::BernoulliOpKernel<plat::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/bernoulli_op.cu b/paddle/fluid/operators/bernoulli_op.cu
deleted file mode 100644
index 030f7cb7d7c..00000000000
--- a/paddle/fluid/operators/bernoulli_op.cu
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <thrust/random.h>
-#include <thrust/transform.h>
-#include <vector>
-
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/bernoulli_op.h"
-#include "paddle/fluid/platform/transform.h"
-
-namespace paddle {
-namespace operators {
-// it can be consistent with cpu when CUDAGenerator is provided.
-template <typename T>
-struct BernoulliCudaFunctor {
-  unsigned int seed_;
-  unsigned int offset_;
-  __host__ __device__ BernoulliCudaFunctor(unsigned int seed,
-                                           unsigned int offset)
-      : seed_(seed), offset_(offset) {}
-
-  __host__ __device__ T operator()(const unsigned int n, const T p) const {
-    // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
-    // lines of error messages if, and it should be refined.
-    PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
-                   "The probability should be >=0 and <= 1, but got %f", p);
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
-    rng.discard(n + offset_);
-    return static_cast<T>(dist(rng) < p);
-  }
-};
-
-template <typename T>
-class BernoulliOpKernel<platform::CUDADeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto x = ctx.Input<framework::Tensor>("X");
-    auto out = ctx.Output<framework::Tensor>("Out");
-    auto* in_data = x->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    int64_t size = x->numel();
-
-    int device_id = ctx.GetPlace().GetDeviceId();
-    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
-    auto seed_offset = gen_cuda->IncrementOffset(1);
-    int64_t gen_offset = size * seed_offset.second;
-    platform::Transform<platform::CUDADeviceContext> trans;
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-    auto* context =
-        static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
-
-    trans(*context, index_sequence_begin, index_sequence_begin + size, in_data,
-          out_data,
-          BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
-                                  static_cast<unsigned int>(gen_offset)));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    bernoulli, ops::BernoulliOpKernel<plat::CUDADeviceContext, float>,
-    ops::BernoulliOpKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/platform/transform.h b/paddle/fluid/platform/transform.h
index e3a39146287..be051ff9219 100644
--- a/paddle/fluid/platform/transform.h
+++ b/paddle/fluid/platform/transform.h
@@ -141,6 +141,53 @@ struct Transform<platform::CUDADeviceContext> {
 #endif
   }
 };
+
+template <>
+struct Transform<pten::GPUContext> {
+  template <typename InputIter, typename OutputIter, typename UnaryOperation>
+  void operator()(const pten::GPUContext& context, InputIter first,
+                  InputIter last, OutputIter result, UnaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first),
+                      details::CastToCUDATransformIterator(last),
+                      details::CastToCUDATransformIterator(result), op);
+#else
+    thrust::transform(thrust::cuda::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first),
+                      details::CastToCUDATransformIterator(last),
+                      details::CastToCUDATransformIterator(result), op);
+#endif
+  }
+
+  template <typename InputIter1, typename InputIter2, typename OutputIter,
+            typename BinaryOperation>
+  void operator()(const pten::GPUContext& context, InputIter1 first1,
+                  InputIter1 last1, InputIter2 first2, OutputIter result,
+                  BinaryOperation op) {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first1),
+                      details::CastToCUDATransformIterator(last1),
+                      details::CastToCUDATransformIterator(first2),
+                      details::CastToCUDATransformIterator(result), op);
+#else
+    thrust::transform(thrust::cuda::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first1),
+                      details::CastToCUDATransformIterator(last1),
+                      details::CastToCUDATransformIterator(first2),
+                      details::CastToCUDATransformIterator(result), op);
+#endif
+  }
+};
 #endif
 
 }  // namespace platform
diff --git a/paddle/pten/kernels/bernoulli_kernel.h b/paddle/pten/kernels/bernoulli_kernel.h
new file mode 100644
index 00000000000..f2d15e54cb1
--- /dev/null
+++ b/paddle/pten/kernels/bernoulli_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/device_context.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out);
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/cpu/bernoulli_kernel.cc b/paddle/pten/kernels/cpu/bernoulli_kernel.cc
new file mode 100644
index 00000000000..8e8e0ef0406
--- /dev/null
+++ b/paddle/pten/kernels/cpu/bernoulli_kernel.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/bernoulli_kernel.h"
+#include <random>
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+namespace pten {
+
+template <typename T>
+inline T BernoulliFunctor(T p, T rand) {
+  PADDLE_ENFORCE_LE(p,
+                    1.0,
+                    pten::errors::OutOfRange(
+                        "The probability should be <= 1, but got %f", p));
+  PADDLE_ENFORCE_GE(p,
+                    0.0,
+                    pten::errors::OutOfRange(
+                        "The probability should be >= 0, but got %f", p));
+  return static_cast<T>(rand < p);
+}
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
+  auto numel = x.numel();
+  auto* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  std::uniform_real_distribution<T> dist(0.0, 1.0);
+  auto gen_ptr = ctx.GetGenerator();
+  auto engine = gen_ptr->GetCPUEngine();
+
+  for (int64_t i = 0; i < numel; ++i) {
+    out_data[i] = BernoulliFunctor(x_data[i], dist(*engine));
+  }
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(
+    bernoulli, CPU, ALL_LAYOUT, pten::BernoulliKernel, float, double) {}
diff --git a/paddle/pten/kernels/gpu/bernoulli_kernel.cu b/paddle/pten/kernels/gpu/bernoulli_kernel.cu
new file mode 100644
index 00000000000..e759eae1033
--- /dev/null
+++ b/paddle/pten/kernels/gpu/bernoulli_kernel.cu
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <thrust/execution_policy.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
+#include <algorithm>
+#include <vector>
+#include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/bernoulli_kernel.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/platform/transform.h"
+
+namespace pten {
+
+template <typename T>
+struct BernoulliCudaFunctor {
+  unsigned int seed_;
+  unsigned int offset_;
+  __host__ __device__ BernoulliCudaFunctor(unsigned int seed,
+                                           unsigned int offset)
+      : seed_(seed), offset_(offset) {}
+
+  __host__ __device__ T operator()(const unsigned int n, const T p) const {
+    // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
+    // lines of error messages, and it should be refined.
+    PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
+                   "The probability should be >=0 and <= 1, but got %f",
+                   p);
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
+    rng.discard(n + offset_);
+    return static_cast<T>(dist(rng) < p);
+  }
+};
+
+template <typename T, typename Context>
+void BernoulliKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     DenseTensor* out) {
+  auto numel = x.numel();
+  auto* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  auto gen_cuda = ctx.GetGenerator();
+  auto seed_offset = gen_cuda->IncrementOffset(1);
+  int64_t gen_offset = numel * seed_offset.second;
+  paddle::platform::Transform<pten::GPUContext> trans;
+  thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+  trans(ctx,
+        index_sequence_begin,
+        index_sequence_begin + numel,
+        x_data,
+        out_data,
+        BernoulliCudaFunctor<T>(static_cast<unsigned int>(seed_offset.first),
+                                static_cast<unsigned int>(gen_offset)));
+}
+
+}  // namespace pten
+
+PT_REGISTER_KERNEL(
+    bernoulli, GPU, ALL_LAYOUT, pten::BernoulliKernel, float, double) {}
-- 
GitLab
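The moved GPU kernel keeps the counter-based sampling trick from the deleted operator: instead of sharing one stateful engine across threads, each element n re-seeds a cheap minstd engine with the generator's seed and discards n + offset draws, so every index gets its own reproducible uniform sample regardless of thread scheduling. Below is a minimal host-side sketch of that per-element draw using only the C++ standard library; it is not Paddle code. The fixed seed/offset pair stands in for what gen_cuda->IncrementOffset(1) returns in the patch, and uniform_at is a hypothetical helper introduced here for illustration.

#include <cstdio>
#include <random>

// Sketch of BernoulliCudaFunctor's per-element draw, with std:: in place of
// thrust:: (minstd_rand exists in both). seed/offset are assumed fixed here;
// in the patch they come from the generator's IncrementOffset(1), which
// advances the offset so successive kernel launches do not reuse draws.
float uniform_at(unsigned int seed, unsigned int offset, unsigned int n) {
  std::minstd_rand rng(seed);
  rng.discard(n + offset);  // jump the engine to this element's position
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  return dist(rng);
}

int main() {
  const unsigned int seed = 42, offset = 0;  // stand-ins for generator state
  const float p[4] = {0.1f, 0.5f, 0.9f, 1.0f};
  for (unsigned int i = 0; i < 4; ++i) {
    // Bernoulli sample: 1 exactly when the element's uniform draw is < p.
    int sample = uniform_at(seed, offset, i) < p[i] ? 1 : 0;
    std::printf("i=%u p=%.1f -> %d\n", i, p[i], sample);
  }
  return 0;
}

Because the engine is re-seeded per element, the output is a pure function of (seed, offset, n); that is what lets the thrust::transform over a counting_iterator run in any thread order and still produce results consistent with a sequential CPU pass when the same generator state is supplied.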