提交 bac1426d 编写于 作者: Q qijun

add_op kernel implementation

上级 6f2eba3e
...@@ -17,6 +17,18 @@ limitations under the License. */ ...@@ -17,6 +17,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <>
DeviceType* KernelContext::get_eigen_device<CPUPlace>() {
return device_context_.get_eigen_device<DeviceType>();
}
#ifndef PADDLE_ONLY_CPU
template <>
DeviceType* KernelContext::get_eigen_device<GPUPlace>() {
return device_context_.get_eigen_device<DeviceType>();
}
#endif
std::string OperatorBase::DebugString() const { std::string OperatorBase::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << "=================\n"; ss << "=================\n";
......
...@@ -29,6 +29,21 @@ limitations under the License. */ ...@@ -29,6 +29,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/**
 * EigenDeviceConverter maps a Place type to the Eigen device type used with
 * it, exposed as the nested alias EigenDeviceType.
 *
 * The primary template is only declared, never defined, so using a Place
 * type without a specialization below is a compile-time error.
 */
template <typename PlaceType>
struct EigenDeviceConverter;

// CPUPlace -> Eigen::DefaultDevice.
template <>
struct EigenDeviceConverter<CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
// GPUPlace -> Eigen::GpuDevice; omitted when PADDLE_ONLY_CPU is defined.
template <>
struct EigenDeviceConverter<GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class OperatorBase; class OperatorBase;
/** /**
...@@ -72,33 +87,39 @@ class OperatorBase { ...@@ -72,33 +87,39 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class OpKernel { /**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. Users should construct it before running the Operator.
*/
class KernelContext {
public: public:
/** KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
* KernelContext is the only parameter of Kernel Run function. const platform::DeviceContext& device_context)
* Run will get input/output variables, state such as momentum and : op_(*op), scope_(scope), device_context_(device_context) {}
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const { const Variable* Input(int index) const {
return scope_->GetVariable(op_.outputs_[index]); return scope_->GetVariable(op_.inputs_[index]);
} }
const OperatorBase& op_; Variable* Output(int index) const {
const std::shared_ptr<Scope>& scope_; return scope_->GetVariable(op_.outputs_[index]);
const platform::DeviceContext& device_context_; }
};
// Returns the device context this KernelContext was constructed with.
// The member is held as a reference to const, so the accessor returns a
// reference to const as well; returning a non-const reference would discard
// the qualifier and fail to compile.
const platform::DeviceContext& device_context() const {
  return device_context_;
}
template <typename PlaceType, typename DeviceType = EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device();
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
class OpKernel {
public:
virtual void Compute(const KernelContext& context) const = 0; virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernel() {}
......
...@@ -35,7 +35,7 @@ class Tensor { ...@@ -35,7 +35,7 @@ class Tensor {
template <typename T> template <typename T>
T* data() const { const T* data() const {
PADDLE_ENFORCE( PADDLE_ENFORCE(
holder_ != nullptr, holder_ != nullptr,
"Tenosr has not been initialized. Call Tensor::mutable_data first."); "Tenosr has not been initialized. Call Tensor::mutable_data first.");
...@@ -58,6 +58,20 @@ class Tensor { ...@@ -58,6 +58,20 @@ class Tensor {
offset_); offset_);
} }
/**
 * Returns a writable pointer to the tensor's storage on `place`, viewed as T.
 *
 * A new PlaceholderImpl<T> of product(dims_) * sizeof(T) bytes is allocated,
 * and offset_ is reset to 0, when any of the following holds:
 *   - no holder exists yet;
 *   - the current holder lives on a different place;
 *   - the current holder is smaller than the required bytes plus offset_.
 * Otherwise the existing holder is reused.
 *
 * The returned address is the holder's base pointer advanced by offset_.
 */
template <typename T,  // must be POD types
          typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(paddle::platform::Place place) {
  const auto payload_bytes = product(dims_) * sizeof(T);

  // Some versions of boost::variant don't have operator!=, so the place
  // comparison is written as !(a == b).
  const bool needs_new_holder = holder_ == nullptr ||
                                !(holder_->Place() == place) ||
                                holder_->Size() < payload_bytes + offset_;
  if (needs_new_holder) {
    holder_.reset(new PlaceholderImpl<T>(place, payload_bytes));
    offset_ = 0;
  }

  const auto base_address = reinterpret_cast<uintptr_t>(holder_->Ptr());
  return reinterpret_cast<T*>(base_address + offset_);
}
size_t NumElements() const { return product(dims_); } size_t NumElements() const { return product(dims_); }
template <typename T, size_t NDIMS> template <typename T, size_t NDIMS>
......
#include <paddle/framework/op_registry.h> #include "paddle/operators/add_op.h"
#include <paddle/framework/tensor.h> #include "paddle/framework/op_registry.h"
#include <paddle/operators/add_op.h> #include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,9 +36,10 @@ The equation is: Out = X + Y ...@@ -36,9 +36,10 @@ The equation is: Out = X + Y
)DOC"); )DOC");
} }
}; };
} // namespace op } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); add_two,
\ No newline at end of file ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>);
\ No newline at end of file
#include <paddle/operators/add_op.h> #define EIGEN_USE_GPU
#include <paddle/framework/op_registry.h>
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(add_two, REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace>); paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file \ No newline at end of file
#pragma once #pragma once
#include <glog/logging.h> #include "glog/logging.h"
#include <paddle/framework/operator.h> #include "paddle/framework/operator.h"
//#include "paddle/operators/add_op_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place> // Place can be CPUPlace or GPUPlace
template <typename Place, typename DataType>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel {
public: public:
void Compute(const KernelContext &context) const override { void Compute(const KernelContext& context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name(); auto* input0 = context.Input(0);
auto* input1 = context.Input(1);
auto* output = context.Output(0);
output->mutable_data<DataType>(Place());
output->flat<T>().device(*(context.get_eigen_device<Place>())) =
input0->flat<T>() + input1->flat<T>();
} }
}; };
} // namespace op } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册