From bac1426d47727a9ea101dd42135a0800c2c5e023 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 14 Jul 2017 16:57:03 +0800 Subject: [PATCH] add_op kernel implementation --- paddle/framework/operator.cc | 12 +++++++ paddle/framework/operator.h | 67 +++++++++++++++++++++++------------- paddle/framework/tensor.h | 16 ++++++++- paddle/operators/add_op.cc | 11 +++--- paddle/operators/add_op.cu | 8 +++-- paddle/operators/add_op.h | 21 +++++++---- 6 files changed, 97 insertions(+), 38 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 8f7adff8b..25d120c9a 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,6 +17,18 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} +#endif + std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d3c55e0ce..48cfeeb73 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -29,6 +29,21 @@ limitations under the License. */ namespace paddle { namespace framework { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class OperatorBase; /** @@ -72,33 +87,39 @@ class OperatorBase { AttributeMap attrs_; }; -class OpKernel { +/** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ +class KernelContext { public: - /** - * KernelContext is the only parameter of Kernel Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. - */ - class KernelContext { - public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } - const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; - }; + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + platform::DeviceContext& device_context() const { return device_context_; } + template ::EigenDeviceType> + DeviceType* get_eigen_device(); + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; +}; + +class OpKernel { + public: virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index e14b75d0e..01244f617 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -35,7 +35,7 @@ class Tensor { template - T* data() const { + const T* data() const { PADDLE_ENFORCE( holder_ != nullptr, "Tenosr has not been initialized. Call Tensor::mutable_data first."); @@ -58,6 +58,20 @@ class Tensor { offset_); } + template ::value>::type* = nullptr> + T* mutable_data(paddle::platform::Place place) { + if (holder_ == nullptr || + !(holder_->Place() == + place) /* some versions of boost::variant don't have operator!= */ + || holder_->Size() < product(dims_) * sizeof(T) + offset_) { + holder_.reset(new PlaceholderImpl(place, product(dims_) * sizeof(T))); + offset_ = 0; + } + return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + offset_); + } + size_t NumElements() const { return product(dims_); } template diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 2766f0bf2..ef39e426f 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -1,6 +1,6 @@ -#include -#include -#include +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" namespace paddle { namespace operators { @@ -36,9 +36,10 @@ The equation is: Out = X + Y )DOC"); } }; -} // namespace op +} // namespace operators } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP_CPU_KERNEL( - add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); \ No newline at end of file + add_two, + ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); \ No newline at end of file diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 5979345ff..f4a4fb16a 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,5 +1,7 @@ -#include -#include +#define EIGEN_USE_GPU + +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + paddle::operators::AddKernel); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 17d459dbc..27a477a3a 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -1,17 +1,26 @@ #pragma once -#include -#include +#include "glog/logging.h" +#include "paddle/framework/operator.h" +//#include "paddle/operators/add_op_functor.h" namespace paddle { namespace operators { -template +// Place can be CPUPlace or GPUPlace +template class AddKernel : public framework::OpKernel { public: - void Compute(const KernelContext &context) const override { - LOG(INFO) << "Add kernel in " << typeid(Place).name(); + void Compute(const KernelContext& context) const override { + auto* input0 = context.Input(0); + auto* input1 = context.Input(1); + + auto* output = context.Output(0); + output->mutable_data(Place()); + + output->flat().device(*(context.get_eigen_device())) = + input0->flat() + input1->flat(); } }; -} // namespace op +} // namespace operators } // namespace paddle -- GitLab