Commit 11f9f5fb authored by dongzhihong

"fix const dependency hell"

Parent 984225ec
@@ -22,14 +22,14 @@ namespace framework {
 template <>
 Eigen::DefaultDevice* KernelContext::GetEigenDevice<
     platform::CPUPlace, Eigen::DefaultDevice>() const {
-  return device_context_.get_eigen_device<Eigen::DefaultDevice>();
+  return device_context_->get_eigen_device<Eigen::DefaultDevice>();
 }
 
 #ifndef PADDLE_ONLY_CPU
 template <>
 Eigen::GpuDevice*
 KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
-  return device_context_.get_eigen_device<Eigen::GpuDevice>();
+  return device_context_->get_eigen_device<Eigen::GpuDevice>();
 }
 #endif
@@ -88,7 +88,7 @@ class OperatorBase {
   /// Net will call this function to Run an op.
   virtual void Run(const std::shared_ptr<Scope>& scope,
-                   const platform::DeviceContext& dev_ctx) const = 0;
+                   platform::DeviceContext& dev_ctx) const = 0;
 
   // Get a input with argument's name described in `op_proto`
   const std::string& Input(const std::string& name) const;
@@ -113,8 +113,8 @@ class OperatorBase {
 class KernelContext {
  public:
   KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
-                const platform::DeviceContext& device_context)
-      : op_(*op), scope_(scope), device_context_(device_context) {}
+                platform::DeviceContext& device_context)
+      : op_(*op), scope_(scope), device_context_(&device_context) {}
 
   const Variable* Input(int index) const {
     return scope_->GetVariable(op_.inputs_[index]);
@@ -155,11 +155,11 @@ class KernelContext {
             typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
   DeviceType* GetEigenDevice() const;
 
-  platform::Place GetPlace() const { return device_context_.GetPlace(); }
+  platform::Place GetPlace() const { return device_context_->GetPlace(); }
 
   const OperatorBase& op_;
-  const std::shared_ptr<Scope>& scope_;
-  const platform::DeviceContext& device_context_;
+  const std::shared_ptr<Scope> scope_;
+  platform::DeviceContext* device_context_;
 };
 
 class OpKernel {
@@ -213,7 +213,7 @@ class OperatorWithKernel : public OperatorBase {
       std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
 
   void Run(const std::shared_ptr<Scope>& scope,
-           const platform::DeviceContext& dev_ctx) const final {
+           platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
     opKernel->Compute(KernelContext(this, scope, dev_ctx));
   }
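This hunk is the heart of the commit. `KernelContext` previously held its device context as a `const platform::DeviceContext&`, so every caller up the stack — `OperatorBase::Run`, `OperatorWithKernel::Run`, and each kernel — had to accept a `const` context even when the kernel needed to mutate it (for example, to lazily create a random generator, as in `device_context.h` below). Storing a non-owning pointer to a mutable `DeviceContext` breaks that chain, because member-function constness in C++ is shallow: inside a `const` method the pointer itself is const, but the object it points at is not. The same hunk also changes `scope_` from a `const std::shared_ptr<Scope>&` to a by-value `shared_ptr`, so the context keeps the scope alive rather than referencing a caller's handle, and the pointer change is why the accessors in `operator.cc` above switched from `.` to `->`. A minimal sketch of the pointer idea, using stand-in types rather than the real Paddle classes:

```cpp
#include <iostream>

// Stand-in types, not the real Paddle classes: the point is that a const
// reference member forces const on everything, while a pointer member does
// not, because member-function constness applies to the pointer, not to the
// pointee.
struct DeviceContext {
  int& Counter() { return counter_; }  // non-const: mutates context state
  int counter_ = 0;
};

class KernelContext {
 public:
  // Take the dependency by mutable reference, store a non-owning pointer.
  explicit KernelContext(DeviceContext& ctx) : device_context_(&ctx) {}

  // This method can stay const even though it mutates the DeviceContext:
  // const here only makes device_context_ a `DeviceContext* const`.
  int NextTick() const { return ++device_context_->Counter(); }

 private:
  DeviceContext* device_context_;  // non-owning, like the member in the diff
};

int main() {
  DeviceContext dc;
  KernelContext kc(dc);
  std::cout << kc.NextTick() << "\n";  // prints 1; dc.counter_ is now 1
}
```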
@@ -7,25 +7,15 @@
 namespace paddle {
 namespace operators {
 
-template <typename T, typename DeviceContext>
-bool Gaussian(DeviceContext& ctx,
-              framework::Tensor* output,
-              const int size,
-              const T& mean,
-              const T& std,
-              const T& seed) {
-  return false;
-}
-
 template <typename T>
-bool Gaussian(platform::CPUDeviceContext& ctx,
-              framework::Tensor* output,
+bool Gaussian(platform::CPUDeviceContext* ctx,
+              T* output,
               const int size,
               const T& mean,
               const T& std,
               const T& seed) {
-  auto g = ctx.RandGenerator(seed);
-  std::normal_distribution<double> distribution(mean, std);
+  auto g = ctx->RandGenerator(seed);
+  std::normal_distribution<T> distribution(mean, std);
   for (int i = 0; i < size; ++i) {
     output[i] = distribution(g);
   }
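The CPU overload now writes into a raw `T*` buffer instead of a `framework::Tensor*`, and the distribution is parameterized on `T` instead of being hard-coded to `double`. Two caveats are visible here: `std::normal_distribution<T>` requires a floating-point `T`, and `auto g = ctx->RandGenerator(seed)` copies the engine (a reference would need `auto&`), so the context's own generator state is not advanced. A self-contained sketch of the same loop; `std::mt19937` is an assumption standing in for the context's generator type:

```cpp
#include <iostream>
#include <random>
#include <vector>

// Fill a raw buffer with normally distributed samples, as the CPU path does.
// std::normal_distribution<T> requires a floating-point T.
template <typename T>
void FillGaussian(T* output, int size, T mean, T stddev, unsigned seed) {
  std::mt19937 engine(seed);
  std::normal_distribution<T> dist(mean, stddev);
  for (int i = 0; i < size; ++i) {
    output[i] = dist(engine);  // each draw advances the engine's state
  }
}

int main() {
  std::vector<float> buf(4);
  FillGaussian(buf.data(), 4, 0.0f, 1.0f, 42u);
  for (float v : buf) std::cout << v << ' ';
}
```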
@@ -34,13 +24,13 @@ bool Gaussian(platform::CPUDeviceContext& ctx,
 #ifndef PADDLE_ONLY_CPU
 template <typename T>
-bool Gaussian(platform::CUDADeviceContext& ctx,
-              framework::Tensor* output,
+bool Gaussian(platform::CUDADeviceContext* ctx,
+              T* output,
               const int size,
               const T& mean,
               const T& std,
               const T& seed) {
-  auto g = RandGenerator(seed);
+  auto g = ctx->RandGenerator(seed);
   return curandGenerateNormal(g, output, size, mean, std);
 }
 #endif
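The GPU overload hands the buffer straight to cuRAND (the diff also fixes `RandGenerator` being called without its `ctx` object). Two cuRAND details matter here: `curandGenerateNormal` fills `float` buffers only (`double` needs `curandGenerateNormalDouble`), and it returns a `curandStatus_t` whose success value `CURAND_STATUS_SUCCESS` is 0, so converting the status directly to `bool`, as the added line does, yields `false` on success. A hedged sketch that compares against the status explicitly (assumes a CUDA toolchain, device memory behind `dev_out`, and an initialized generator):

```cpp
#include <cstddef>
#include <curand.h>

// Sketch only: dev_out must point to device memory for n floats; with some
// pseudorandom generators n must also be even, since Box-Muller produces
// samples in pairs.
bool GaussianGpu(curandGenerator_t gen, float* dev_out, size_t n, float mean,
                 float stddev) {
  return curandGenerateNormal(gen, dev_out, n, mean, stddev) ==
         CURAND_STATUS_SUCCESS;
}
```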
@@ -53,13 +43,24 @@ public:
     auto std = context.op_.GetAttr<T>("std");
     auto seed = context.op_.GetAttr<T>("seed");
     auto* output = context.Output(0)->GetMutable<framework::Tensor>();
-    output->mutable_data<T>(context.GetPlace());
-    Gaussian(context.device_context_,
-             output,
-             framework::product(output->dims()),
-             mean,
-             std,
-             seed);
+    auto place = context.GetPlace();
+    if (platform::is_cpu_place(place)) {
+      Gaussian(
+          dynamic_cast<platform::CPUDeviceContext*>(context.device_context_),
+          output->mutable_data<T>(context.GetPlace()),
+          framework::product(output->dims()),
+          mean,
+          std,
+          seed);
+    } else {
+      Gaussian(
+          dynamic_cast<platform::CUDADeviceContext*>(context.device_context_),
+          output->mutable_data<T>(context.GetPlace()),
+          framework::product(output->dims()),
+          mean,
+          std,
+          seed);
+    }
   }
 };
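With `device_context_` now a base-class pointer, the kernel selects the concrete context by checking the place and downcasting. `dynamic_cast` returns `nullptr` when the place and the actual context type disagree, which makes a mismatch detectable rather than undefined behavior; the diff above does not check the result, so the sketch below (stand-in types again) adds that check:

```cpp
#include <cassert>
#include <iostream>

// Stand-in hierarchy for the dispatch pattern: the kernel holds a base-class
// pointer and downcasts to the concrete context chosen by the place.
struct DeviceContext { virtual ~DeviceContext() = default; };
struct CPUDeviceContext : DeviceContext { const char* name = "cpu"; };
struct CUDADeviceContext : DeviceContext { const char* name = "cuda"; };

enum class Place { kCPU, kGPU };

void RunKernel(DeviceContext* ctx, Place place) {
  if (place == Place::kCPU) {
    auto* cpu = dynamic_cast<CPUDeviceContext*>(ctx);
    assert(cpu && "place says CPU but context is not a CPUDeviceContext");
    std::cout << cpu->name << "\n";
  } else {
    auto* gpu = dynamic_cast<CUDADeviceContext*>(ctx);
    assert(gpu && "place says GPU but context is not a CUDADeviceContext");
    std::cout << gpu->name << "\n";
  }
}

int main() {
  CPUDeviceContext cpu;
  RunKernel(&cpu, Place::kCPU);  // prints "cpu"
}
```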
@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
     return retv;
   }
 
-  const random_generator_type& RandGenerator(const int seed) {
+  random_generator_type& RandGenerator(const int seed) {
     if (!rand_generator_) {
       random_seed_ = seed;
       rand_generator_.reset(new random_generator_type(random_seed_));
@@ -98,7 +98,7 @@ class CUDADeviceContext : public DeviceContext {
         "cudaStreamSynchronize failed");
   }
 
-  const curandGenerator_t RandGenerator(const int seed) {
+  curandGenerator_t RandGenerator(const int seed) {
     if (!rand_generator_) {
       random_seed_ = seed;
       GPUPlaceGuard guard(gpu_place_);
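Dropping `const` from both `RandGenerator` overloads is what the pointer change upstream enables: the generator is created lazily on first call, so the accessor must mutate the context, and CPU callers also need a mutable engine because every draw advances its state. A sketch of the lazy pattern; `std::mt19937` stands in for the header's `random_generator_type`:

```cpp
#include <memory>
#include <random>

// Lazy-generator pattern: the accessor cannot be const because the first
// call creates and seeds the engine, and callers need a mutable reference
// since drawing numbers changes the engine's state.
class CPUDeviceContextSketch {
 public:
  std::mt19937& RandGenerator(int seed) {
    if (!rand_generator_) {
      rand_generator_.reset(new std::mt19937(seed));  // seeded on first use
    }
    return *rand_generator_;
  }

 private:
  std::unique_ptr<std::mt19937> rand_generator_;
};
```

An alternative would be to mark `rand_generator_` as `mutable` and keep the accessor `const`, but the commit instead removes `const` along the whole call chain, which keeps the mutation visible in the signatures.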