diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 8f7adff8b3982e91a3d7f6d598cd62d5005d5f17..25d120c9a97b50b02ae37fa281adeda39d3ab80d 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,6 +17,18 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} +#endif + std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d3c55e0ceb6737772be38b536a107900ee895b12..48cfeeb73144eb827b6e298ecf4df775d1793095 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -29,6 +29,21 @@ limitations under the License. */ namespace paddle { namespace framework { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class OperatorBase; /** @@ -72,33 +87,39 @@ class OperatorBase { AttributeMap attrs_; }; -class OpKernel { +/** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ +class KernelContext { public: - /** - * KernelContext is the only parameter of Kernel Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. - */ - class KernelContext { - public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } - const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; - }; + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + platform::DeviceContext& device_context() const { return device_context_; } + template ::EigenDeviceType> + DeviceType* get_eigen_device(); + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; +}; + +class OpKernel { + public: virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index e14b75d0e0851be1d62d55b32520ad413688f334..01244f617c2fb2e86f2e4d960bc0c5b593fc34fa 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -35,7 +35,7 @@ class Tensor { template - T* data() const { + const T* data() const { PADDLE_ENFORCE( holder_ != nullptr, "Tenosr has not been initialized. Call Tensor::mutable_data first."); @@ -58,6 +58,20 @@ class Tensor { offset_); } + template ::value>::type* = nullptr> + T* mutable_data(paddle::platform::Place place) { + if (holder_ == nullptr || + !(holder_->Place() == + place) /* some versions of boost::variant don't have operator!= */ + || holder_->Size() < product(dims_) * sizeof(T) + offset_) { + holder_.reset(new PlaceholderImpl(place, product(dims_) * sizeof(T))); + offset_ = 0; + } + return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + offset_); + } + size_t NumElements() const { return product(dims_); } template diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 2766f0bf258ed863a4297c1e4a2be4673cbf3044..ef39e426fde5a26e14f90201d16a8a1b5fad9099 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -1,6 +1,6 @@ -#include -#include -#include +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" namespace paddle { namespace operators { @@ -36,9 +36,10 @@ The equation is: Out = X + Y )DOC"); } }; -} // namespace op +} // namespace operators } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP_CPU_KERNEL( - add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); \ No newline at end of file + add_two, + ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); \ No newline at end of file diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 5979345fffd68d71ba09dc96874d8ff9471bdbcc..f4a4fb16a6e39c1c15e613f86ad5c02703aa33d0 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,5 +1,7 @@ -#include -#include +#define EIGEN_USE_GPU + +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + paddle::operators::AddKernel); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 17d459dbc86af73fa86934a5ccce9da509c59c6b..27a477a3ac070174030c31683c922444abfe6e81 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -1,17 +1,26 @@ #pragma once -#include -#include +#include "glog/logging.h" +#include "paddle/framework/operator.h" +//#include "paddle/operators/add_op_functor.h" namespace paddle { namespace operators { -template +// Place can be CPUPlace or GPUPlace +template class AddKernel : public framework::OpKernel { public: - void Compute(const KernelContext &context) const override { - LOG(INFO) << "Add kernel in " << typeid(Place).name(); + void Compute(const KernelContext& context) const override { + auto* input0 = context.Input(0); + auto* input1 = context.Input(1); + + auto* output = context.Output(0); + output->mutable_data(Place()); + + output->flat().device(*(context.get_eigen_device())) = + input0->flat() + input1->flat(); } }; -} // namespace op +} // namespace operators } // namespace paddle