Commit 11f9f5fb authored by D dongzhihong

"fix const dependency hell"

Parent 984225ec
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
// Specialization for CPU: return the default Eigen device owned by the
// device context attached to this KernelContext.
template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
    platform::CPUPlace, Eigen::DefaultDevice>() const {
  // device_context_ is a non-owning pointer member, hence '->' not '.'.
  return device_context_->get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
// Specialization for GPU builds: return the Eigen GPU device from the
// attached device context.
template <>
Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
  return device_context_->get_eigen_device<Eigen::GpuDevice>();
}
#endif
......
...@@ -88,7 +88,7 @@ class OperatorBase { ...@@ -88,7 +88,7 @@ class OperatorBase {
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0; platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto` // Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
...@@ -113,8 +113,8 @@ class OperatorBase { ...@@ -113,8 +113,8 @@ class OperatorBase {
class KernelContext { class KernelContext {
public: public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context) platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {} : op_(*op), scope_(scope), device_context_(&device_context) {}
const Variable* Input(int index) const { const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]); return scope_->GetVariable(op_.inputs_[index]);
...@@ -155,11 +155,11 @@ class KernelContext { ...@@ -155,11 +155,11 @@ class KernelContext {
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const; DeviceType* GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_->GetPlace(); }
const OperatorBase& op_; const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_; const std::shared_ptr<Scope> scope_;
const platform::DeviceContext& device_context_; platform::DeviceContext* device_context_;
}; };
class OpKernel { class OpKernel {
...@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase {
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final { platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx));
} }
......
...@@ -7,25 +7,15 @@ ...@@ -7,25 +7,15 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// NOTE(review): generic fallback overload, removed by this commit. It
// unconditionally returned false (i.e. never generated anything) for any
// device type, so the per-device overloads below replace it.
template <typename T, typename DeviceContext>
bool Gaussian(DeviceContext& ctx,
framework::Tensor* output,
const int size,
const T& mean,
const T& std,
const T& seed) {
return false;
}
template <typename T> template <typename T>
bool Gaussian(platform::CPUDeviceContext& ctx, bool Gaussian(platform::CPUDeviceContext* ctx,
framework::Tensor* output, T* output,
const int size, const int size,
const T& mean, const T& mean,
const T& std, const T& std,
const T& seed) { const T& seed) {
auto g = ctx.RandGenerator(seed); auto g = ctx->RandGenerator(seed);
std::normal_distribution<double> distribution(mean, std); std::normal_distribution<T> distribution(mean, std);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
output[i] = distribution(g); output[i] = distribution(g);
} }
...@@ -34,13 +24,13 @@ bool Gaussian(platform::CPUDeviceContext& ctx, ...@@ -34,13 +24,13 @@ bool Gaussian(platform::CPUDeviceContext& ctx,
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <typename T> template <typename T>
bool Gaussian(platform::CUDADeviceContext& ctx, bool Gaussian(platform::CUDADeviceContext* ctx,
framework::Tensor* output, T* output,
const int size, const int size,
const T& mean, const T& mean,
const T& std, const T& std,
const T& seed) { const T& seed) {
auto g = RandGenerator(seed); auto g = ctx->RandGenerator(seed);
return curandGenerateNormal(g, output, size, mean, std); return curandGenerateNormal(g, output, size, mean, std);
} }
#endif #endif
...@@ -53,13 +43,24 @@ public: ...@@ -53,13 +43,24 @@ public:
auto std = context.op_.GetAttr<T>("std"); auto std = context.op_.GetAttr<T>("std");
auto seed = context.op_.GetAttr<T>("seed"); auto seed = context.op_.GetAttr<T>("seed");
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace()); auto place = context.GetPlace();
Gaussian(context.device_context_, if (platform::is_cpu_place(place)) {
output, Gaussian(
dynamic_cast<platform::CPUDeviceContext*>(context.device_context_),
output->mutable_data<T>(context.GetPlace()),
framework::product(output->dims()), framework::product(output->dims()),
mean, mean,
std, std,
seed); seed);
} else {
Gaussian(
dynamic_cast<platform::CUDADeviceContext*>(context.device_context_),
output->mutable_data<T>(context.GetPlace()),
framework::product(output->dims()),
mean,
std,
seed);
}
} }
}; };
......
...@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext { ...@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
return retv; return retv;
} }
const random_generator_type& RandGenerator(const int seed) { random_generator_type& RandGenerator(const int seed) {
if (!rand_generator_) { if (!rand_generator_) {
random_seed_ = seed; random_seed_ = seed;
rand_generator_.reset(new random_generator_type(random_seed_)); rand_generator_.reset(new random_generator_type(random_seed_));
...@@ -98,7 +98,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -98,7 +98,7 @@ class CUDADeviceContext : public DeviceContext {
"cudaStreamSynchronize failed"); "cudaStreamSynchronize failed");
} }
const curandGenerator_t RandGenerator(const int seed) { curandGenerator_t RandGenerator(const int seed) {
if (!rand_generator_) { if (!rand_generator_) {
random_seed_ = seed; random_seed_ = seed;
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册