提交 d649dbf4 编写于 作者: Q qijun

implement add_op kernel

上级 bac1426d
......@@ -18,13 +18,15 @@ namespace paddle {
namespace framework {
template <>
DeviceType* KernelContext::get_eigen_device<CPUPlace>() {
return device_context_.get_eigen_device<DeviceType>();
Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
DeviceType* KernelContext::get_eigen_device<GPUPlace>() {
DeviceType* OpKernel::KernelContext::get_eigen_device<platform::GPUPlace>()
const {
return device_context_.get_eigen_device<DeviceType>();
}
#endif
......
......@@ -33,13 +33,13 @@ template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<CPUPlace> {
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<GPUPlace> {
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
......@@ -87,13 +87,15 @@ class OperatorBase {
AttributeMap attrs_;
};
/**
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
......@@ -107,19 +109,16 @@ class KernelContext {
return scope_->GetVariable(op_.outputs_[index]);
}
platform::DeviceContext& device_context() const { return device_context_; }
template <typename PlaceType, typename DeviceType = EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device();
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device() const;
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
};
class OpKernel {
public:
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
......
......@@ -35,7 +35,7 @@ class Tensor {
template <typename T>
const T* data() const {
T* data() const {
PADDLE_ENFORCE(
holder_ != nullptr,
"Tenosr has not been initialized. Call Tensor::mutable_data first.");
......@@ -90,7 +90,7 @@ class Tensor {
// flat to rank = 1
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>({NumElements()});
return shaped<T, 1>(make_ddim({static_cast<int>(NumElements())}));
}
// to TensorType Vec
......@@ -114,7 +114,7 @@ class Tensor {
template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>({NumElements()});
return shaped<T, 1>(make_ddim({static_cast<int>(NumElements())}));
}
template <typename T>
......
......@@ -40,6 +40,6 @@ The equation is: Out = X + Y
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL(
add_two,
::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>);
\ No newline at end of file
typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
AddKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
\ No newline at end of file
#define EIGEN_USE_GPU
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
AddKernel_GPU_float);
\ No newline at end of file
......@@ -6,19 +6,18 @@
namespace paddle {
namespace operators {
// Place can be CPUPlace or GPUPlace
template <typename Place, typename DataType>
template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
public:
void Compute(const KernelContext& context) const override {
auto* input0 = context.Input(0);
auto* input1 = context.Input(1);
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto* output = context.Output(0);
output->mutable_data<DataType>(Place());
output->mutable_data<T>(Place());
output->flat<T>().device(*(context.get_eigen_device<Place>())) =
input0->flat<T>() + input1->flat<T>();
input0.flat<T>() + input1.flat<T>();
}
};
......
......@@ -15,14 +15,15 @@ namespace paddle {
namespace platform {
template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device();
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device();
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
#endif
......
......@@ -32,17 +32,14 @@ class DeviceContext {
virtual Place GetPlace() const = 0;
template <typename DeviceType>
DeviceType* get_eigen_device();
DeviceType* get_eigen_device() const;
};
class CPUDeviceContext : public DeviceContext {
public:
Eigen::DefaultDevice* eigen_device() {
if (!eigen_device_) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
return eigen_device_.get();
}
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
Place GetPlace() const override {
Place retv = CPUPlace();
......@@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册