提交 0ff81920 编写于 作者: Y Yu Yang

Add OperatorWithKernel class

* User can register OpKernel to its Ops. The OpKernelMap saved in
  OperatorWithKernel. Each Op which inherits OperatorWithKernel will
  use `OpKernel::Compute` instead of Run.
上级 2749b71f
......@@ -15,7 +15,6 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()
......
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework;
namespace paddle {
namespace framework {
class CosineOp : public OperatorWithKernel {
class CosineOp : public OperatorBase {
public:
void Run(const OpRunContext* context) const override {
printf("%s\n", DebugString().c_str());
}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OperatorWithKernel {
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
......@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}
......@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4);
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx);
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
......
......@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str();
}
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
......@@ -14,44 +14,22 @@ limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
class OperatorBase;
class DeviceContext {};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const;
Variable* Output(int index) const;
public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -77,7 +55,10 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0;
const platform::DeviceContext& dev_ctx) const = 0;
protected:
std::string Type() const { return desc_.type(); }
public:
OpDesc desc_;
......@@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap attrs_;
};
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
};
class OperatorWithKernel : public OperatorBase {
public:
virtual ~OperatorWithKernel() {}
struct OpKernelKey {
platform::Place place_;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};
struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const {
OpRunContext op_ctx(this, scope, dev_ctx);
Run(&op_ctx);
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}
/// when implement an Op, your should implement this function.
/// this function should be moved to OpKernel later
virtual void Run(const OpRunContext* context) const = 0;
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
};
} // namespace framework
} // namespace paddle
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
......@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OperatorTest : public OperatorWithKernel {
class OperatorTest : public OperatorBase {
public:
void Run(const OpRunContext* ctx) const override {
float scale = ctx->op_->GetAttr<float>("scale");
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
"Output(1) should not initialized");
auto output1 = ctx->scope_->CreateVariable("output1");
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
printf("get attr %s = %f\n", "scale", scale);
printf("%s\n", DebugString().c_str());
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
}
};
......@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
TEST(OperatorBase, DebugString) {
TEST(OperatorBase, all) {
OpDesc op_desc;
op_desc.set_type("test_operator");
std::vector<std::string> inputs = {"IN1", "IN2"};
for (auto& input : inputs) {
op_desc.add_inputs(input);
}
std::vector<std::string> outputs = {"OUT1", "OUT2"};
for (auto& output : outputs) {
op_desc.add_outputs(output);
}
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);
DeviceContext device_context;
platform::CPUDeviceContext device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->inputs_, inputs);
ASSERT_EQ(op->outputs_, outputs);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
op->Run(scope, &device_context);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
delete op;
}
} // namespace framework
......
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
#pragma once
#include "paddle/framework/op_registry.h"
using namespace paddle::framework;
namespace paddle {
namespace operators {
class CosineOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace operators
} // namespace operators
......@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include <paddle/platform/place.h>
#include <unsupported/Eigen/CXX11/Tensor>
namespace paddle {
namespace platform {
......@@ -31,9 +31,16 @@ namespace platform {
class DeviceContext {
public:
virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
};
class CPUDeviceContext : public DeviceContext {};
class CPUDeviceContext : public DeviceContext {
public:
Place GetPlace() const override {
Place retv = CPUPlace();
return retv;
}
};
#ifndef PADDLE_ONLY_CPU
......@@ -61,6 +68,11 @@ class CUDADeviceContext : public DeviceContext {
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
}
Place GetPlace() const override {
Place retv = GPUPlace();
return retv;
}
void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册