From 6f80b5f1df2b4d77857338f44c3159388602457b Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 25 Jul 2017 12:00:47 +0800 Subject: [PATCH] "move to template function" --- paddle/operators/random_op.cc | 34 ++++++++++++++++++----- paddle/operators/random_op.cu | 7 ++--- paddle/operators/random_op.h | 28 +++++++++++++------ paddle/platform/device_context.h | 46 ++++++++++++++++++++------------ 4 files changed, 81 insertions(+), 34 deletions(-) diff --git a/paddle/operators/random_op.cc b/paddle/operators/random_op.cc index b85ff842207..a536ee74b40 100644 --- a/paddle/operators/random_op.cc +++ b/paddle/operators/random_op.cc @@ -1,13 +1,33 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/operators/random_op.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace operators { -using paddle::platform::GPUPlace; -template -bool Gaussian( - Generator g, T* output, const int size, const T& mean, const T& std) { +// using paddle::platform::CPUPlace; +// template +template +bool Gaussian(platform::CPUDeviceContext& ctx, + framework::Tensor* output, + const int size, + const T& mean, + const T& std, + const T& seed) { + auto g = ctx.RandGenerator(seed); std::normal_distribution distribution(mean, std); for (int i = 0; i < size; ++i) { output[i] = distribution(g()); @@ -24,7 +44,9 @@ protected: PADDLE_ENFORCE(outputs.size() == 1, "Output size of RandomOp must be one."); PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, "Inputs/Outputs of RandomOp must all be set."); - outputs[0]->set_dims(context.op_.attrs_.at("shape")); + outputs[0]->Resize( + framework::make_ddim(this->GetAttr>("shape"))); + // outputs[0]->set_dims(context.op_.attrs_.at("shape")); } }; @@ -32,7 +54,7 @@ class RandomOpMaker : public framework::OpProtoAndCheckerMaker { public: RandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr>("Shape", "The shape of matrix to be randomized"); + AddAttr>("shape", "The shape of matrix to be randomized"); AddAttr("seed", "random seed generator.").SetDefault(1337); AddAttr("mean", "mean value of random.").SetDefault(.0); AddAttr("std", "minimum value of random value") diff --git a/paddle/operators/random_op.cu b/paddle/operators/random_op.cu index ea1096aeb97..40b642d8a19 100644 --- a/paddle/operators/random_op.cu +++ b/paddle/operators/random_op.cu @@ -4,9 +4,10 @@ namespace paddle { namespace operators { -using paddle::platform::GPUPlace; -template -bool Gaussian(Generator g, T* output, const int size, const T& mean, const T& std) { +template +bool Gaussian(platform::CUDADeviceContext &ctx, framework::Tensor* output, + const int size, const T& mean, const T& std, const T& seed) { + auto g = RandGenerator(seed); return curandGenerateNormal(g, output, size, mean, std); } diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h index 1b5fb16de1a..a82b3afec87 100644 --- a/paddle/operators/random_op.h +++ b/paddle/operators/random_op.h @@ -6,21 +6,33 @@ namespace paddle { namespace operators { -template -bool Gaussian( - Generator g, T* output, const int size, const T& mean, const T& std); +template +bool Gaussian(DeviceContext& ctx, + framework::Tensor* output, + const int size, + const T& mean, + const T& std, + const T& seed); template class RandomOpKernel : public framework::OpKernel { public: void Compute(const framework::KernelContext& context) const override { - auto mean = context.op_.attrs_.at("mean"); - auto std = context.op_.attrs_.at("std"); - auto seed = context.op_.attrs_.at("seed"); + auto mean = context.op_.GetAttr("mean"); + auto std = context.op_.GetAttr("std"); + auto seed = context.op_.GetAttr("seed"); auto* output = context.Output(0)->GetMutable(); output->mutable_data(context.GetPlace()); - - Gaussian(, output, output->size(), mean, std) : + Gaussian(context.device_context_, + output, + framework::product(output->dims()), + mean, + std, + seed); + // Gaussian(context.device_context_, + // output, + // framework::product(output->dims()), + // mean, std, seed); // std::default_random_engine generator(seed); // std::normal_distribution distribution(mean, std); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fe6f13e399a..b8af4abd7f9 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -39,6 +39,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: + typedef std::mt19937 random_generator_type; CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } @@ -48,7 +49,17 @@ class CPUDeviceContext : public DeviceContext { return retv; } + const random_generator_type& RandGenerator(const int seed) { + if (!rand_generator_) { + random_seed_ = seed; + rand_generator_.reset(new random_generator_type(random_seed_)); + } + return *rand_generator_.get(); + } + private: + int random_seed_; + std::unique_ptr rand_generator_; std::unique_ptr eigen_device_; }; @@ -87,6 +98,24 @@ class CUDADeviceContext : public DeviceContext { "cudaStreamSynchronize failed"); } + const curandGenerator_t RandGenerator(const int seed) { + if (!rand_generator_) { + random_seed_ = seed; + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), + "curandCreateGenerator failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( + rand_generator_, random_seed_), + "curandSetPseudoRandomGeneratorSeed failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetStream(rand_generator_, stream_), + "curandSetStream failed"); + } + return rand_generator_; + } + cudaStream_t stream() { return stream_; } Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } @@ -115,23 +144,6 @@ class CUDADeviceContext : public DeviceContext { return dnn_handle_; } - curandGenerator_t curand_generator() { - if (!rand_generator_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_), - "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetStream(rand_generator_, stream_), - "curandSetStream failed"); - } - return rand_generator_; - } - ~CUDADeviceContext() { Wait(); if (blas_handle_) { -- GitLab