From 0ff819207230ac345efefc0a37a3883e81d43c02 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 12 Jul 2017 14:02:57 +0800 Subject: [PATCH] Add OperatorWithKernel class * User can register OpKernel to its Ops. The OpKernelMap saved in OperatorWithKernel. Each Op which inherits OperatorWithKernel will use `OpKernel::Compute` instead of Run. --- paddle/CMakeLists.txt | 1 - paddle/framework/op_registry_test.cc | 33 ++++---- paddle/framework/operator.cc | 8 -- paddle/framework/operator.h | 117 ++++++++++++++++++--------- paddle/framework/operator_test.cc | 39 ++++----- paddle/operators/.clang-format | 5 -- paddle/operators/CMakeLists.txt | 0 paddle/operators/demo_op.h | 59 -------------- paddle/platform/device_context.h | 18 ++++- 9 files changed, 127 insertions(+), 153 deletions(-) delete mode 100644 paddle/operators/.clang-format delete mode 100644 paddle/operators/CMakeLists.txt delete mode 100644 paddle/operators/demo_op.h diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 2c1eb7521d8..58a35564f83 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,7 +15,6 @@ if(Boost_FOUND) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) - add_subdirectory(operators) add_subdirectory(pybind) endif() diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index c4baafc2aeb..f5d45a80bb8 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,17 +1,15 @@ #include "paddle/framework/op_registry.h" #include -#include "paddle/framework/operator.h" -#include "paddle/operators/demo_op.h" using namespace paddle::framework; namespace paddle { namespace framework { -class CosineOp : public OperatorWithKernel { +class CosineOp : public OperatorBase { public: - void Run(const OpRunContext* context) const override { - printf("%s\n", DebugString().c_str()); - } + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} + void InferShape(const std::shared_ptr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OperatorWithKernel { +class MyTestOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} + public: - void Run(const OpRunContext* ctx) const override { - printf("%s\n", DebugString().c_str()); - printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); - } }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); float scale_get = op->GetAttr("scale"); ASSERT_EQ(scale_get, scale); } @@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); } @@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(4); paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - auto dev_ctx = DeviceContext(); + paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); - op->Run(scope, &dev_ctx); + op->Run(scope, dev_ctx); int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3db3706e47d..8f7adff8b39 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const { return ss.str(); } -const Variable* OpRunContext::Input(int index) const { - return scope_->GetVariable(op_->inputs_[index]); -} - -Variable* OpRunContext::Output(int index) const { - return scope_->GetVariable(op_->outputs_[index]); -} - } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6570d586981..0ce422e007c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,44 +14,22 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include +#include +#include #include #include #include #include -#include "paddle/framework/attr_checker.h" -#include "paddle/framework/op_desc.pb.h" -#include "paddle/framework/scope.h" -#include "paddle/utils/Error.h" - namespace paddle { namespace framework { class OperatorBase; -class DeviceContext {}; - -/** - * OpRunContext is the only parameter of Operator's Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * OpRunContext. User should construct it before run the Operator. - */ -class OpRunContext { - public: - OpRunContext(const OperatorBase* op, const std::shared_ptr scope, - const DeviceContext* device_context) - : op_(op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const; - Variable* Output(int index) const; - - public: - const OperatorBase* op_; - const std::shared_ptr scope_; - const DeviceContext* device_context_; -}; - /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -77,7 +55,10 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const = 0; + const platform::DeviceContext& dev_ctx) const = 0; + + protected: + std::string Type() const { return desc_.type(); } public: OpDesc desc_; @@ -86,22 +67,84 @@ class OperatorBase { AttributeMap attrs_; }; +class OpKernel { + public: + /** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ + class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } + + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; + }; + + virtual void Compute(const KernelContext& context) const = 0; + + virtual ~OpKernel() {} +}; + class OperatorWithKernel : public OperatorBase { public: - virtual ~OperatorWithKernel() {} + struct OpKernelKey { + platform::Place place_; - virtual void InferShape(const std::shared_ptr& scope) const {} + OpKernelKey() = default; + OpKernelKey(const platform::DeviceContext& dev_ctx) { + place_ = dev_ctx.GetPlace(); + } + + bool operator==(const OpKernelKey& o) const { return place_ == o.place_; } + }; + + struct OpKernelHash { + std::hash hash_; + size_t operator()(const OpKernelKey& key) const { + return hash_(platform::is_gpu_place(key.place_)); + } + }; + + using OpKernelMap = + std::unordered_map, OpKernelHash>; void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const { - OpRunContext op_ctx(this, scope, dev_ctx); - Run(&op_ctx); + const platform::DeviceContext& dev_ctx) const final { + auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); + opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); } - /// when implement an Op, your should implement this function. - /// this function should be moved to OpKernel later - virtual void Run(const OpRunContext* context) const = 0; + static std::unordered_map& + AllOpKernels() { + static std::unordered_map g_all_op_kernels; + return g_all_op_kernels; + }; }; } // namespace framework } // namespace paddle + +#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \ + struct __op_kernel_register__##type##__ { \ + __op_kernel_register__##type##__() { \ + ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ + key.place_ = PlaceType(); \ + ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ + .reset(new KernelType()); \ + } \ + }; \ + static __op_kernel_register__##type##__ __reg_kernel_##type##__ diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 48808dabb27..86f45f108a5 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -19,17 +19,15 @@ limitations under the License. */ namespace paddle { namespace framework { -class OperatorTest : public OperatorWithKernel { +class OperatorTest : public OperatorBase { public: - void Run(const OpRunContext* ctx) const override { - float scale = ctx->op_->GetAttr("scale"); - PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); - PADDLE_ENFORCE(ctx->Output(0) == nullptr, - "Output(1) should not initialized"); - auto output1 = ctx->scope_->CreateVariable("output1"); - PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); - printf("get attr %s = %f\n", "scale", scale); - printf("%s\n", DebugString().c_str()); + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + float scale = GetAttr("scale"); + ASSERT_NEAR(scale, 3.14, 1e-5); + ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); } }; @@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) -TEST(OperatorBase, DebugString) { +TEST(OperatorBase, all) { OpDesc op_desc; op_desc.set_type("test_operator"); - std::vector inputs = {"IN1", "IN2"}; - for (auto& input : inputs) { - op_desc.add_inputs(input); - } - std::vector outputs = {"OUT1", "OUT2"}; - for (auto& output : outputs) { - op_desc.add_outputs(output); - } + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); float scale = 3.14; attr->set_f(scale); - DeviceContext device_context; + platform::CPUDeviceContext device_context; auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->inputs_, inputs); - ASSERT_EQ(op->outputs_, outputs); ASSERT_EQ(op->GetAttr("scale"), scale); - op->Run(scope, &device_context); + scope->CreateVariable("OUT1"); + op->Run(scope, device_context); + std::cout << op->DebugString() << std::endl; + delete op; } } // namespace framework diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format deleted file mode 100644 index 29282dc87e2..00000000000 --- a/paddle/operators/.clang-format +++ /dev/null @@ -1,5 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: Google -Standard: Cpp11 -... diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h deleted file mode 100644 index d0b7420b4e2..00000000000 --- a/paddle/operators/demo_op.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "paddle/framework/op_registry.h" - -using namespace paddle::framework; - -namespace paddle { -namespace operators { - -class CosineOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddType("cos"); - AddComment("This is cos op"); - } -}; - -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) - -class MyTestOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - auto my_checker = [](int i) { - PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); - }; - AddAttr("test_attr", "a simple test attribute") - .AddCustomChecker(my_checker); - AddType("my_test_op"); - AddComment("This is my_test op"); - } -}; - -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) - -} // namespace operators -} // namespace operators diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 160eb4e1206..e3c2cd2647f 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU #endif -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include +#include namespace paddle { namespace platform { @@ -31,9 +31,16 @@ namespace platform { class DeviceContext { public: virtual ~DeviceContext() {} + virtual Place GetPlace() const = 0; }; -class CPUDeviceContext : public DeviceContext {}; +class CPUDeviceContext : public DeviceContext { + public: + Place GetPlace() const override { + Place retv = CPUPlace(); + return retv; + } +}; #ifndef PADDLE_ONLY_CPU @@ -61,6 +68,11 @@ class CUDADeviceContext : public DeviceContext { eigen_device_ = new Eigen::GpuDevice(eigen_stream_); } + Place GetPlace() const override { + Place retv = GPUPlace(); + return retv; + } + void Wait() { paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), "cudaStreamSynchronize failed"); -- GitLab