From 11f9f5fb172f620d5221c93fe26196ebd244df79 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Wed, 26 Jul 2017 00:40:37 +0800 Subject: [PATCH] "fix const dependency hell" --- paddle/framework/operator.cc | 4 +-- paddle/framework/operator.h | 14 ++++----- paddle/operators/random_op.h | 49 ++++++++++++++++---------------- paddle/platform/device_context.h | 4 +-- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 18e327089fe..0a317dffa96 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice* KernelContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return device_context_.get_eigen_device(); + return device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice* KernelContext::GetEigenDevice() const { - return device_context_.get_eigen_device(); + return device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f59314f8288..5db041ea329 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -88,7 +88,7 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const = 0; + platform::DeviceContext& dev_ctx) const = 0; // Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; @@ -113,8 +113,8 @@ class OperatorBase { class KernelContext { public: KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} + platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(&device_context) {} const Variable* Input(int index) const { return scope_->GetVariable(op_.inputs_[index]); @@ -155,11 +155,11 @@ class KernelContext { typename EigenDeviceConverter::EigenDeviceType> DeviceType* GetEigenDevice() const; - platform::Place GetPlace() const { return device_context_.GetPlace(); } + platform::Place GetPlace() const { return device_context_->GetPlace(); } const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; + const std::shared_ptr scope_; + platform::DeviceContext* device_context_; }; class OpKernel { @@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase { std::unordered_map, OpKernelHash>; void Run(const std::shared_ptr& scope, - const platform::DeviceContext& dev_ctx) const final { + platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx)); } diff --git a/paddle/operators/random_op.h b/paddle/operators/random_op.h index f8e1a90a1d1..8231b6b6134 100644 --- a/paddle/operators/random_op.h +++ b/paddle/operators/random_op.h @@ -7,25 +7,15 @@ namespace paddle { namespace operators { -template -bool Gaussian(DeviceContext& ctx, - framework::Tensor* output, - const int size, - const T& mean, - const T& std, - const T& seed) { - return false; -} - template -bool Gaussian(platform::CPUDeviceContext& ctx, - framework::Tensor* output, +bool Gaussian(platform::CPUDeviceContext* ctx, + T* output, const int size, const T& mean, const T& std, const T& seed) { - auto g = ctx.RandGenerator(seed); - std::normal_distribution distribution(mean, std); + auto g = ctx->RandGenerator(seed); + std::normal_distribution distribution(mean, std); for (int i = 0; i < size; ++i) { output[i] = distribution(g); } @@ -34,13 +24,13 @@ bool Gaussian(platform::CPUDeviceContext& ctx, #ifndef PADDLE_ONLY_CPU template -bool Gaussian(platform::CUDADeviceContext& ctx, - framework::Tensor* output, +bool Gaussian(platform::CUDADeviceContext* ctx, + T* output, const int size, const T& mean, const T& std, const T& seed) { - auto g = RandGenerator(seed); + auto g = ctx->RandGenerator(seed); return curandGenerateNormal(g, output, size, mean, std); } #endif @@ -53,13 +43,24 @@ public: auto std = context.op_.GetAttr("std"); auto seed = context.op_.GetAttr("seed"); auto* output = context.Output(0)->GetMutable(); - output->mutable_data(context.GetPlace()); - Gaussian(context.device_context_, - output, - framework::product(output->dims()), - mean, - std, - seed); + auto place = context.GetPlace(); + if (platform::is_cpu_place(place)) { + Gaussian( + dynamic_cast(context.device_context_), + output->mutable_data(context.GetPlace()), + framework::product(output->dims()), + mean, + std, + seed); + } else { + Gaussian( + dynamic_cast(context.device_context_), + output->mutable_data(context.GetPlace()), + framework::product(output->dims()), + mean, + std, + seed); + } } }; diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index b8af4abd7f9..7bc34bd5458 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext { return retv; } - const random_generator_type& RandGenerator(const int seed) { + random_generator_type& RandGenerator(const int seed) { if (!rand_generator_) { random_seed_ = seed; rand_generator_.reset(new random_generator_type(random_seed_)); @@ -98,7 +98,7 @@ class CUDADeviceContext : public DeviceContext { "cudaStreamSynchronize failed"); } - const curandGenerator_t RandGenerator(const int seed) { + curandGenerator_t RandGenerator(const int seed) { if (!rand_generator_) { random_seed_ = seed; GPUPlaceGuard guard(gpu_place_); -- GitLab