提交 bac1426d 编写于 作者: Q qijun

add_op kernel implementation

上级 6f2eba3e
......@@ -17,6 +17,18 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <>
DeviceType* KernelContext::get_eigen_device<CPUPlace>() {
return device_context_.get_eigen_device<DeviceType>();
}
#ifndef PADDLE_ONLY_CPU
template <>
DeviceType* KernelContext::get_eigen_device<GPUPlace>() {
return device_context_.get_eigen_device<DeviceType>();
}
#endif
std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
......
......@@ -29,6 +29,21 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Maps a memory Place to the Eigen device type used to evaluate tensor
// expressions on that place. The primary template is intentionally left
// undefined: instantiating it with an unsupported Place is a compile error.
template <typename T>
struct EigenDeviceConverter;
// CPU tensors are evaluated on Eigen's default device.
template <>
struct EigenDeviceConverter<CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
// GPU tensors use Eigen's CUDA device; only compiled with GPU support.
template <>
struct EigenDeviceConverter<GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class OperatorBase;
/**
......@@ -72,33 +87,39 @@ class OperatorBase {
AttributeMap attrs_;
};
class OpKernel {
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
platform::DeviceContext& device_context() const { return device_context_; }
template <typename PlaceType, typename DeviceType = EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device();
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
class OpKernel {
public:
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
......
......@@ -35,7 +35,7 @@ class Tensor {
template <typename T>
T* data() const {
const T* data() const {
PADDLE_ENFORCE(
holder_ != nullptr,
"Tenosr has not been initialized. Call Tensor::mutable_data first.");
......@@ -58,6 +58,20 @@ class Tensor {
offset_);
}
// Returns a mutable pointer to this tensor's data on `place`, (re)allocating
// the underlying buffer when it is missing, lives on a different place, or
// is too small to hold product(dims_) elements of T past the current offset.
// SFINAE restricts T to POD types: the buffer is raw, uninitialized memory,
// so no constructors are run.
template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(paddle::platform::Place place) {
// Reallocate if there is no holder yet, the holder is on another place,
// or its buffer cannot fit the requested number of elements.
if (holder_ == nullptr ||
!(holder_->Place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < product(dims_) * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, product(dims_) * sizeof(T)));
// Fresh allocation: data starts at the beginning of the new buffer.
offset_ = 0;
}
// Apply offset_ in bytes via uintptr_t before casting back to T*.
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
offset_);
}
// Total number of elements, i.e. the product of all dimension extents.
size_t NumElements() const { return product(dims_); }
template <typename T, size_t NDIMS>
......
// Own header first, then framework dependencies (project convention uses
// quoted includes; the diff residue duplicated the old angle-bracket forms).
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
......@@ -36,9 +36,10 @@ The equation is: Out = X + Y
)DOC");
}
};
}  // namespace operators
}  // namespace paddle

// Register the add_two operator and its CPU kernel. The float argument
// matches AddKernel's <Place, DataType> template parameters.
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL(
    add_two,
    ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>);
// EIGEN_USE_GPU must be defined before any Eigen header is pulled in so
// that Eigen compiles its GPU device support into this translation unit.
#define EIGEN_USE_GPU
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"

// Register the GPU kernel for add_two; float matches the CPU registration.
REGISTER_OP_GPU_KERNEL(
    add_two, paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
#pragma once
// Quoted-include form per project convention (the diff residue duplicated
// the old angle-bracket forms of the same headers).
#include "glog/logging.h"
#include "paddle/framework/operator.h"
//#include "paddle/operators/add_op_functor.h"
namespace paddle {
namespace operators {
template <typename Place>
// Place can be CPUPlace or GPUPlace
template <typename Place, typename DataType>
class AddKernel : public framework::OpKernel {
public:
void Compute(const KernelContext &context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name();
void Compute(const KernelContext& context) const override {
auto* input0 = context.Input(0);
auto* input1 = context.Input(1);
auto* output = context.Output(0);
output->mutable_data<DataType>(Place());
output->flat<T>().device(*(context.get_eigen_device<Place>())) =
input0->flat<T>() + input1->flat<T>();
}
};
} // namespace op
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册