diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 25d120c9a97b50b02ae37fa281adeda39d3ab80d..3c6376c1503f3c3d816e0500b9ad79a99857ef20 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -18,13 +18,15 @@ namespace paddle { namespace framework { template <> -DeviceType* KernelContext::get_eigen_device() { - return device_context_.get_eigen_device(); +Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< + platform::CPUPlace, Eigen::DefaultDevice>() const { + return device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -DeviceType* KernelContext::get_eigen_device() { +DeviceType* OpKernel::KernelContext::get_eigen_device() + const { return device_context_.get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 48cfeeb73144eb827b6e298ecf4df775d1793095..558d4a0b6769ad11e775e6eeb5a48bcbf48067f3 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -33,13 +33,13 @@ template struct EigenDeviceConverter; template <> -struct EigenDeviceConverter { +struct EigenDeviceConverter { using EigenDeviceType = Eigen::DefaultDevice; }; #ifndef PADDLE_ONLY_CPU template <> -struct EigenDeviceConverter { +struct EigenDeviceConverter { using EigenDeviceType = Eigen::GpuDevice; }; #endif @@ -87,39 +87,38 @@ class OperatorBase { AttributeMap attrs_; }; -/** - * KernelContext is the only parameter of Kernel Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. - */ -class KernelContext { +class OpKernel { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } - - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } + /** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ + class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } - platform::DeviceContext& device_context() const { return device_context_; } + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } - template ::EigenDeviceType> - DeviceType* get_eigen_device(); + template ::EigenDeviceType> + DeviceType* get_eigen_device() const; - const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; -}; + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; + }; -class OpKernel { - public: virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 01244f617c2fb2e86f2e4d960bc0c5b593fc34fa..784d52cc426b92b509c320e86c0f8e0b6b0e2c13 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -35,7 +35,7 @@ class Tensor { template - const T* data() const { + T* data() const { PADDLE_ENFORCE( holder_ != nullptr, "Tenosr has not been initialized. Call Tensor::mutable_data first."); @@ -90,7 +90,7 @@ class Tensor { // flat to rank = 1 template typename TTypes::Flat flat() { - return shaped({NumElements()}); + return shaped(make_ddim({static_cast(NumElements())})); } // to TensorType Vec @@ -114,7 +114,7 @@ class Tensor { template typename TTypes::ConstFlat flat() const { - return shaped({NumElements()}); + return shaped(make_ddim({static_cast(NumElements())})); } template diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index ef39e426fde5a26e14f90201d16a8a1b5fad9099..7dc6414af2b3378c68b833568d7ac05251461a97 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -40,6 +40,6 @@ The equation is: Out = X + Y } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); -REGISTER_OP_CPU_KERNEL( - add_two, - ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); \ No newline at end of file +typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float> + AddKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float); \ No newline at end of file diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index f4a4fb16a6e39c1c15e613f86ad5c02703aa33d0..0edf142ee4e5f359ea14be02dbf3f7f8855f6db1 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,7 +1,6 @@ -#define EIGEN_USE_GPU - #include "paddle/operators/add_op.h" #include "paddle/framework/op_registry.h" +typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float; REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + AddKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 27a477a3ac070174030c31683c922444abfe6e81..568cb19742cdeebf9752149706b37388c0ab3ad6 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -6,19 +6,18 @@ namespace paddle { namespace operators { -// Place can be CPUPlace or GPUPlace -template +template class AddKernel : public framework::OpKernel { public: void Compute(const KernelContext& context) const override { - auto* input0 = context.Input(0); - auto* input1 = context.Input(1); + auto input0 = context.Input(0)->Get(); + auto input1 = context.Input(1)->Get(); + auto* output = context.Output(0)->GetMutable(); - auto* output = context.Output(0); - output->mutable_data(Place()); + output->mutable_data(Place()); output->flat().device(*(context.get_eigen_device())) = - input0->flat() + input1->flat(); + input0.flat() + input1.flat(); } }; diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 960ef0a5955bfe5f7d33b7c8e4524176b0dbfda6..9c1d94e9e703caf2db92ca4a8eac975317e6b945 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -15,14 +15,15 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::DefaultDevice* DeviceContext::get_eigen_device() + const { + return reinterpret_cast(this)->eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::GpuDevice* DeviceContext::get_eigen_device() const { + return reinterpret_cast(this)->eigen_device(); } #endif diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7de07d06bed885d6529a884fb81fedbdaba78f4a..2ec7b055994b019cd81af191a6b9cf511bc83489 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -32,17 +32,14 @@ class DeviceContext { virtual Place GetPlace() const = 0; template - DeviceType* get_eigen_device(); + DeviceType* get_eigen_device() const; }; class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice* eigen_device() { - if (!eigen_device_) { - eigen_device_.reset(new Eigen::DefaultDevice()); - } - return eigen_device_.get(); - } + CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } + + Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Place GetPlace() const override { Place retv = CPUPlace(); @@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } + Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } cublasHandle_t cublas_handle() { if (!blas_handle_) {