From 572133400d3f4073d9a9206db5ed1ced3e39623d Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Sun, 30 Jul 2017 22:13:26 +0800 Subject: [PATCH] "update the compute kernel" --- paddle/framework/operator.h | 8 ++--- paddle/operators/random_op.cc | 47 +++++++++++++++++++------- paddle/operators/random_op.cu | 25 +++++++++++++- paddle/operators/random_op.h | 57 ++------------------------------ paddle/platform/device_context.h | 19 +++++++---- 5 files changed, 77 insertions(+), 79 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5db041ea329..9ba661968cd 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -88,7 +88,7 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - platform::DeviceContext& dev_ctx) const = 0; + const platform::DeviceContext& dev_ctx) const = 0; // Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; @@ -113,7 +113,7 @@ class OperatorBase { class KernelContext { public: KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - platform::DeviceContext& device_context) + const platform::DeviceContext& device_context) : op_(*op), scope_(scope), device_context_(&device_context) {} const Variable* Input(int index) const { @@ -159,7 +159,7 @@ class KernelContext { const OperatorBase& op_; const std::shared_ptr scope_; - platform::DeviceContext* device_context_; + const platform::DeviceContext* device_context_; }; class OpKernel { @@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase { std::unordered_map, OpKernelHash>; void Run(const std::shared_ptr& scope, - platform::DeviceContext& dev_ctx) const final { + const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx)); } diff --git a/paddle/operators/random_op.cc b/paddle/operators/random_op.cc index 726f6504e71..16e526dc4f6 100644 --- a/paddle/operators/random_op.cc +++ b/paddle/operators/random_op.cc @@ -19,7 +19,28 @@ namespace paddle { namespace operators { -class RandomOp : public framework::OperatorWithKernel { +template +class GaussianRandomOpKernel + : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& context) const override { + auto mean = context.op_.GetAttr("mean"); + auto std = context.op_.GetAttr("std"); + // auto seed = context.op_.GetAttr("seed"); + auto* output = context.Output(0)->GetMutable(); + T* r = output->mutable_data(context.GetPlace()); + auto ctx = + static_cast(context.device_context_); + // generator need to modify context + auto g = const_cast(ctx)->RandGenerator(); + std::normal_distribution distribution(mean, std); + for (int i = 0; i < framework::product(output->dims()); ++i) { + r[i] = distribution(g); + } + } +}; + +class GaussianRandomOp : public framework::OperatorWithKernel { protected: void InferShape( const std::vector& inputs, @@ -33,20 +54,21 @@ protected: } }; -class RandomOpMaker : public framework::OpProtoAndCheckerMaker { +class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { public: - RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + GaussianRandomOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddAttr>("shape", "The shape of matrix to be randomized"); - AddAttr("seed", "random seed generator.").SetDefault(1337); + // AddAttr("seed", "random seed generator.").SetDefault(1337); AddAttr("mean", "mean value of random.").SetDefault(.0); AddAttr("std", "minimum value of random value") .SetDefault(1.0) .LargerThan(.0); AddOutput("Out", "output matrix of random op"); AddComment(R"DOC( -Random Operator fill a matrix in normal distribution. -The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std) +GaussianRandom Operator fill a matrix in normal distribution. +The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std) )DOC"); } }; @@ -54,10 +76,11 @@ The eqution : Out = Random(Shape=(d0, d1, ...), Dtype, mean, std) } // namespace operators } // namespace paddle -REGISTER_OP(random, - paddle::operators::RandomOp, - paddle::operators::RandomOpMaker); +REGISTER_OP(gaussian_random, + paddle::operators::GaussianRandomOp, + paddle::operators::GaussianRandomOpMaker); -typedef paddle::operators::RandomOpKernel - RandomOpKernel_CPU_float; -REGISTER_OP_CPU_KERNEL(random, RandomOpKernel_CPU_float); +typedef paddle::operators::GaussianRandomOpKernel + GaussianRandomOpKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(gaussian_random, GaussianRandomOpKernel_CPU_float); diff --git a/paddle/operators/random_op.cu b/paddle/operators/random_op.cu index b417666c98e..78a00bc8990 100644 --- a/paddle/operators/random_op.cu +++ b/paddle/operators/random_op.cu @@ -1,7 +1,30 @@ #include "paddle/operators/random_op.h" #include "paddle/framework/op_registry.h" +namespace paddle { +namespace operators { + +template +class GaussianRandomOpKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& context) const override { + auto mean = context.op_.GetAttr("mean"); + auto std = context.op_.GetAttr("std"); + auto* output = context.Output(0)->GetMutable(); + T* r = output->mutable_data(context.GetPlace()); + auto ctx = static_cast + (context.device_context_); + // generator need to modify context + auto g = const_cast(ctx)->RandGenerator(); + curandGenerateNormal(g, r, framework::product(output->dims()), mean, std); -typedef paddle::operators::RandomOpKernel + } +}; + +} // namespace operators +} // namespace paddle + + +typedef paddle::operators::GaussianRandomOpKernel RandomOpKernel_GPU_float; REGISTER_OP_GPU_KERNEL(random, RandomOpKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h index 26dba130e43..b463a171d9c 100644 --- a/paddle/operators/random_op.h +++ b/paddle/operators/random_op.h @@ -7,63 +7,10 @@ namespace paddle { namespace operators { -template -bool Gaussian(platform::CPUDeviceContext* ctx, - T* output, - const int size, - const T& mean, - const T& std, - const T& seed) { - auto g = ctx->RandGenerator(seed); - std::normal_distribution distribution(mean, std); - for (int i = 0; i < size; ++i) { - output[i] = distribution(g); - } - return true; -} - -#ifndef PADDLE_ONLY_CPU -template -bool Gaussian(platform::CUDADeviceContext* ctx, - T* output, - const int size, - const T& mean, - const T& std, - const T& seed) { - auto g = ctx->RandGenerator(seed); - return curandGenerateNormal(g, output, size, mean, std); -} -#endif - template -class RandomOpKernel : public framework::OpKernel { +class GaussianRandomOpKernel : public framework::OpKernel { public: - void Compute(const framework::KernelContext& context) const override { - auto mean = context.op_.GetAttr("mean"); - auto std = context.op_.GetAttr("std"); - auto seed = context.op_.GetAttr("seed"); - auto* output = context.Output(0)->GetMutable(); - auto place = context.GetPlace(); - if (platform::is_cpu_place(place)) { - Gaussian( - dynamic_cast(context.device_context_), - output->mutable_data(context.GetPlace()), - framework::product(output->dims()), - mean, - std, - seed); - } else { -#ifndef PADDLE_ONLY_CPU - Gaussian( - dynamic_cast(context.device_context_), - output->mutable_data(context.GetPlace()), - framework::product(output->dims()), - mean, - std, - seed); -#endif - } - } + void Compute(const framework::KernelContext& context) const override {} }; } // namespace operators diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7bc34bd5458..239c25a90c4 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif +#include #include #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -40,7 +41,10 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: typedef std::mt19937 random_generator_type; - CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } + CPUDeviceContext() { + random_seed_ = std::chrono::system_clock::now().time_since_epoch().count(); + eigen_device_.reset(new Eigen::DefaultDevice()); + } Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } @@ -49,16 +53,15 @@ class CPUDeviceContext : public DeviceContext { return retv; } - random_generator_type& RandGenerator(const int seed) { + random_generator_type& RandGenerator() { if (!rand_generator_) { - random_seed_ = seed; rand_generator_.reset(new random_generator_type(random_seed_)); } return *rand_generator_.get(); } private: - int random_seed_; + unsigned random_seed_; std::unique_ptr rand_generator_; std::unique_ptr eigen_device_; }; @@ -81,6 +84,9 @@ class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext { public: + CUDADeviceContext() { + random_seed_ = std::chrono::system_clock::now().time_since_epoch().count(); + } explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -98,9 +104,8 @@ class CUDADeviceContext : public DeviceContext { "cudaStreamSynchronize failed"); } - curandGenerator_t RandGenerator(const int seed) { + curandGenerator_t RandGenerator() { if (!rand_generator_) { - random_seed_ = seed; GPUPlaceGuard guard(gpu_place_); PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), @@ -177,7 +182,7 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t dnn_handle_{nullptr}; - int random_seed_; + unsigned random_seed_; curandGenerator_t rand_generator_{nullptr}; }; -- GitLab