提交 a2e5f652 编写于 作者: Q Qiao Longfei 提交者: GitHub

add operator base (#2725)

Add OperatorBase.

issue: https://github.com/PaddlePaddle/Paddle/issues/2790

Paddle design the Operator with Kernel. OperatorBase has no type and device information when create, One operator can have multiple kernels, Operator will choose a kernel to run according to context. The kernel should be bind to Operator before or during Operator running.
上级 267f9a2c
...@@ -15,6 +15,7 @@ if(Boost_FOUND) ...@@ -15,6 +15,7 @@ if(Boost_FOUND)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind) add_subdirectory(pybind)
endif() endif()
......
...@@ -11,8 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) ...@@ -11,8 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......
...@@ -32,5 +32,5 @@ template <> ...@@ -32,5 +32,5 @@ template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) { void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS); attr->set_type(paddle::framework::AttrType::STRINGS);
} }
} } // namespace framework
} } // namespace paddle
\ No newline at end of file \ No newline at end of file
#pragma once #pragma once
#include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h"
#include <algorithm> #include <algorithm>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
//==================For test================//
class OpBase {
public:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attr_map_;
virtual std::string Run() const = 0;
virtual ~OpBase() {}
};
//=========================================//
// helper class to set attribute type // helper class to set attribute type
struct AttrTypeHelper { struct AttrTypeHelper {
template <typename T> template <typename T>
...@@ -105,7 +92,7 @@ class OpProtoAndCheckerMaker { ...@@ -105,7 +92,7 @@ class OpProtoAndCheckerMaker {
}; };
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OpBase*()>; using OpCreator = std::function<OperatorBase*()>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -118,9 +105,10 @@ class OpRegistry { ...@@ -118,9 +105,10 @@ class OpRegistry {
"Fail to initialize %s's OpProto !", op_type); "Fail to initialize %s's OpProto !", op_type);
} }
static OpBase* CreateOp(const OpDesc& op_desc) { static OperatorBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type(); std::string op_type = op_desc.type();
OpBase* op = creators().at(op_type)(); OperatorBase* op = creators().at(op_type)();
op->desc_ = op_desc;
op->inputs_.reserve((size_t)op_desc.inputs_size()); op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_)); std::back_inserter(op->inputs_));
...@@ -128,9 +116,9 @@ class OpRegistry { ...@@ -128,9 +116,9 @@ class OpRegistry {
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_)); std::back_inserter(op->outputs_));
for (auto& attr : op_desc.attrs()) { for (auto& attr : op_desc.attrs()) {
op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
} }
op_checkers().at(op_type).Check(op->attr_map_); op_checkers().at(op_type).Check(op->attrs_);
return op; return op;
} }
......
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OpBase { class CosineOp : public OperatorWithKernel {
public: public:
virtual std::string Run() const { void Run(const OpRunContext* context) const override {
std::string msg = "CosineOp runs! scale = " + printf("%s\n", DebugString().c_str());
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
} }
}; };
...@@ -28,13 +30,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -28,13 +30,11 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OpBase { class MyTestOp : public OperatorWithKernel {
public: public:
virtual std::string Run() const { void Run(const OpRunContext* ctx) const override {
std::string msg = printf("%s\n", DebugString().c_str());
"MyTestOp runs! test_attr = " + printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
} }
}; };
...@@ -64,19 +64,19 @@ TEST(OpRegistry, CreateOp) { ...@@ -64,19 +64,19 @@ TEST(OpRegistry, CreateOp) {
op_desc.add_inputs("aa"); op_desc.add_inputs("aa");
op_desc.add_outputs("bb"); op_desc.add_outputs("bb");
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.3); attr->set_f(scale);
paddle::framework::OpBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); auto scope = std::make_shared<Scope>();
std::string str = "CosineOp runs! scale = " + std::to_string(3.3); auto dev_ctx = DeviceContext();
ASSERT_EQ(str.size(), debug_str.size()); op->Run(scope, &dev_ctx);
for (size_t i = 0; i < debug_str.length(); ++i) { float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(debug_str[i], str[i]); ASSERT_EQ(scale_get, scale);
}
} }
TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, IllegalAttr) {
...@@ -92,7 +92,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -92,7 +92,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -111,15 +111,14 @@ TEST(OpRegistry, DefaultValue) { ...@@ -111,15 +111,14 @@ TEST(OpRegistry, DefaultValue) {
op_desc.add_inputs("aa"); op_desc.add_inputs("aa");
op_desc.add_outputs("bb"); op_desc.add_outputs("bb");
paddle::framework::OpBase* op = ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); auto scope = std::make_shared<Scope>();
float default_value = 1.0; auto dev_ctx = DeviceContext();
std::string str = "CosineOp runs! scale = " + std::to_string(default_value); op->Run(scope, &dev_ctx);
ASSERT_EQ(str.size(), debug_str.size()); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
} }
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
...@@ -131,7 +130,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -131,7 +130,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -150,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -150,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OpBase* op __attribute__((unused)) = paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -168,14 +167,13 @@ TEST(OpRegistry, CustomChecker) { ...@@ -168,14 +167,13 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
paddle::framework::OpBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run(); auto dev_ctx = DeviceContext();
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4); auto scope = std::make_shared<Scope>();
ASSERT_EQ(str.size(), debug_str.size()); op->Run(scope, &dev_ctx);
for (size_t i = 0; i < debug_str.length(); ++i) { int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(debug_str[i], str[i]); ASSERT_EQ(test_attr, 4);
}
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "=================\n";
ss << "type = " << desc_.type() << "\n";
ss << "inputs = [";
for (auto& ipt : inputs_) {
ss << ipt << ", ";
}
ss << "]\n";
ss << "outputs = [";
for (auto& opt : outputs_) {
ss << opt << ", ";
}
ss << "]\n";
ss << "attr_keys = [";
for (auto& attr : attrs_) {
ss << attr.first << ", ";
}
ss << "]\n";
return ss.str();
}
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
class OperatorBase;
class DeviceContext {};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const;
Variable* Output(int index) const;
public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
* should always construct a proto message OpDesc and call
* OpRegistry::CreateOp(op_desc) to get an Operator instance.
*/
class OperatorBase {
public:
virtual ~OperatorBase() {}
template <typename T>
inline const T& GetAttr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}
std::string DebugString() const;
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0;
public:
OpDesc desc_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
AttributeMap attrs_;
};
class OperatorWithKernel : public OperatorBase {
public:
virtual ~OperatorWithKernel() {}
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const {
OpRunContext op_ctx(this, scope, dev_ctx);
Run(&op_ctx);
}
/// when implement an Op, your should implement this function.
/// this function should be moved to OpKernel later
virtual void Run(const OpRunContext* context) const = 0;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class OperatorTest : public OperatorWithKernel {
public:
void Run(const OpRunContext* ctx) const override {
float scale = ctx->op_->GetAttr<float>("scale");
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
"Output(1) should not initialized");
auto output1 = ctx->scope_->CreateVariable("output1");
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
printf("get attr %s = %f\n", "scale", scale);
printf("%s\n", DebugString().c_str());
}
};
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("test_operator");
AddComment("This is test op");
}
};
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
TEST(OperatorBase, DebugString) {
OpDesc op_desc;
op_desc.set_type("test_operator");
std::vector<std::string> inputs = {"IN1", "IN2"};
for (auto& input : inputs) {
op_desc.add_inputs(input);
}
std::vector<std::string> outputs = {"OUT1", "OUT2"};
for (auto& output : outputs) {
op_desc.add_outputs(output);
}
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);
DeviceContext device_context;
auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->inputs_, inputs);
ASSERT_EQ(op->outputs_, outputs);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
op->Run(scope, &device_context);
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
#pragma once
#include "paddle/framework/op_registry.h"
using namespace paddle::framework;
namespace paddle {
namespace operators {
class CosineOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace operators
} // namespace operators
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册