提交 d649dbf4 编写于 作者: Q qijun

implement add_op kernel

上级 bac1426d
...@@ -18,13 +18,15 @@ namespace paddle { ...@@ -18,13 +18,15 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
DeviceType* KernelContext::get_eigen_device<CPUPlace>() { Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
return device_context_.get_eigen_device<DeviceType>(); platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
DeviceType* KernelContext::get_eigen_device<GPUPlace>() { DeviceType* OpKernel::KernelContext::get_eigen_device<platform::GPUPlace>()
const {
return device_context_.get_eigen_device<DeviceType>(); return device_context_.get_eigen_device<DeviceType>();
} }
#endif #endif
......
...@@ -33,13 +33,13 @@ template <typename T> ...@@ -33,13 +33,13 @@ template <typename T>
struct EigenDeviceConverter; struct EigenDeviceConverter;
template <> template <>
struct EigenDeviceConverter<CPUPlace> { struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice; using EigenDeviceType = Eigen::DefaultDevice;
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
struct EigenDeviceConverter<GPUPlace> { struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice; using EigenDeviceType = Eigen::GpuDevice;
}; };
#endif #endif
...@@ -87,13 +87,15 @@ class OperatorBase { ...@@ -87,13 +87,15 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
/** class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function. * KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and * Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from * device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator. * KernelContext. User should construct it before run the Operator.
*/ */
class KernelContext { class KernelContext {
public: public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext& device_context)
...@@ -107,19 +109,16 @@ class KernelContext { ...@@ -107,19 +109,16 @@ class KernelContext {
return scope_->GetVariable(op_.outputs_[index]); return scope_->GetVariable(op_.outputs_[index]);
} }
platform::DeviceContext& device_context() const { return device_context_; } template <typename PlaceType,
typename DeviceType =
template <typename PlaceType, typename DeviceType = EigenDeviceConverter< typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
PlaceType>::EigenDeviceType> DeviceType* get_eigen_device() const;
DeviceType* get_eigen_device();
const OperatorBase& op_; const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_; const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
class OpKernel {
public:
virtual void Compute(const KernelContext& context) const = 0; virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernel() {}
......
...@@ -35,7 +35,7 @@ class Tensor { ...@@ -35,7 +35,7 @@ class Tensor {
template <typename T> template <typename T>
const T* data() const { T* data() const {
PADDLE_ENFORCE( PADDLE_ENFORCE(
holder_ != nullptr, holder_ != nullptr,
"Tenosr has not been initialized. Call Tensor::mutable_data first."); "Tenosr has not been initialized. Call Tensor::mutable_data first.");
...@@ -90,7 +90,7 @@ class Tensor { ...@@ -90,7 +90,7 @@ class Tensor {
// flat to rank = 1 // flat to rank = 1
template <typename T> template <typename T>
typename TTypes<T>::Flat flat() { typename TTypes<T>::Flat flat() {
return shaped<T, 1>({NumElements()}); return shaped<T, 1>(make_ddim({static_cast<int>(NumElements())}));
} }
// to TensorType Vec // to TensorType Vec
...@@ -114,7 +114,7 @@ class Tensor { ...@@ -114,7 +114,7 @@ class Tensor {
template <typename T> template <typename T>
typename TTypes<T>::ConstFlat flat() const { typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>({NumElements()}); return shaped<T, 1>(make_ddim({static_cast<int>(NumElements())}));
} }
template <typename T> template <typename T>
......
...@@ -40,6 +40,6 @@ The equation is: Out = X + Y ...@@ -40,6 +40,6 @@ The equation is: Out = X + Y
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL( typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
add_two, AddKernel_CPU_float;
::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
\ No newline at end of file \ No newline at end of file
#define EIGEN_USE_GPU
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(add_two, REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>); AddKernel_GPU_float);
\ No newline at end of file \ No newline at end of file
...@@ -6,19 +6,18 @@ ...@@ -6,19 +6,18 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Place can be CPUPlace or GPUPlace template <typename Place, typename T>
template <typename Place, typename DataType>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto* input0 = context.Input(0); auto input0 = context.Input(0)->Get<framework::Tensor>();
auto* input1 = context.Input(1); auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto* output = context.Output(0); output->mutable_data<T>(Place());
output->mutable_data<DataType>(Place());
output->flat<T>().device(*(context.get_eigen_device<Place>())) = output->flat<T>().device(*(context.get_eigen_device<Place>())) =
input0->flat<T>() + input1->flat<T>(); input0.flat<T>() + input1.flat<T>();
} }
}; };
......
...@@ -15,14 +15,15 @@ namespace paddle { ...@@ -15,14 +15,15 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() { Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device(); const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
#endif #endif
......
...@@ -32,17 +32,14 @@ class DeviceContext { ...@@ -32,17 +32,14 @@ class DeviceContext {
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename DeviceType>
DeviceType* get_eigen_device(); DeviceType* get_eigen_device() const;
}; };
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
Eigen::DefaultDevice* eigen_device() { CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
if (!eigen_device_) {
eigen_device_.reset(new Eigen::DefaultDevice()); Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
}
return eigen_device_.get();
}
Place GetPlace() const override { Place GetPlace() const override {
Place retv = CPUPlace(); Place retv = CPUPlace();
...@@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; } cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册