提交 0a320081 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #2815 from reyoung/feature/op_kernel

OperatorWithKernel
...@@ -15,7 +15,6 @@ if(Boost_FOUND) ...@@ -15,7 +15,6 @@ if(Boost_FOUND)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind) add_subdirectory(pybind)
endif() endif()
......
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework; using namespace paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorWithKernel { class CosineOp : public OperatorBase {
public: public:
void Run(const OpRunContext* context) const override { void Run(const std::shared_ptr<Scope>& scope,
printf("%s\n", DebugString().c_str()); const platform::DeviceContext& dev_ctx) const override {}
} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OperatorWithKernel { class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
public: public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) { ...@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale"); float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale); ASSERT_EQ(scale_get, scale);
} }
...@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) { ...@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
} }
...@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) { ...@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4); attr->set_i(4);
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
......
...@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const { ...@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str(); return ss.str();
} }
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
\ No newline at end of file
...@@ -14,44 +14,22 @@ limitations under the License. */ ...@@ -14,44 +14,22 @@ limitations under the License. */
#pragma once #pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class DeviceContext {};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const;
Variable* Output(int index) const;
public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -77,7 +55,10 @@ class OperatorBase { ...@@ -77,7 +55,10 @@ class OperatorBase {
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
protected:
std::string Type() const { return desc_.type(); }
public: public:
OpDesc desc_; OpDesc desc_;
...@@ -86,22 +67,84 @@ class OperatorBase { ...@@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
};
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
virtual ~OperatorWithKernel() {} struct OpKernelKey {
platform::Place place_;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {} OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};
struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const { const platform::DeviceContext& dev_ctx) const final {
OpRunContext op_ctx(this, scope, dev_ctx); auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
Run(&op_ctx); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
} }
/// when implement an Op, your should implement this function. static std::unordered_map<std::string /* op_type */, OpKernelMap>&
/// this function should be moved to OpKernel later AllOpKernels() {
virtual void Run(const OpRunContext* context) const = 0; static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
...@@ -19,17 +19,15 @@ limitations under the License. */ ...@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorTest : public OperatorWithKernel { class OperatorTest : public OperatorBase {
public: public:
void Run(const OpRunContext* ctx) const override { void InferShape(const std::shared_ptr<Scope>& scope) const override {}
float scale = ctx->op_->GetAttr<float>("scale"); void Run(const std::shared_ptr<Scope>& scope,
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); const platform::DeviceContext& dev_ctx) const override {
PADDLE_ENFORCE(ctx->Output(0) == nullptr, float scale = GetAttr<float>("scale");
"Output(1) should not initialized"); ASSERT_NEAR(scale, 3.14, 1e-5);
auto output1 = ctx->scope_->CreateVariable("output1"); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
printf("get attr %s = %f\n", "scale", scale);
printf("%s\n", DebugString().c_str());
} }
}; };
...@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
TEST(OperatorBase, DebugString) { TEST(OperatorBase, all) {
OpDesc op_desc; OpDesc op_desc;
op_desc.set_type("test_operator"); op_desc.set_type("test_operator");
std::vector<std::string> inputs = {"IN1", "IN2"}; *op_desc.mutable_inputs()->Add() = "IN1";
for (auto& input : inputs) { *op_desc.mutable_outputs()->Add() = "OUT1";
op_desc.add_inputs(input);
}
std::vector<std::string> outputs = {"OUT1", "OUT2"};
for (auto& output : outputs) {
op_desc.add_outputs(output);
}
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14; float scale = 3.14;
attr->set_f(scale); attr->set_f(scale);
DeviceContext device_context; platform::CPUDeviceContext device_context;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->inputs_, inputs);
ASSERT_EQ(op->outputs_, outputs);
ASSERT_EQ(op->GetAttr<float>("scale"), scale); ASSERT_EQ(op->GetAttr<float>("scale"), scale);
op->Run(scope, &device_context); scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
delete op;
} }
} // namespace framework } // namespace framework
......
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
#pragma once
#include "paddle/framework/op_registry.h"
using namespace paddle::framework;
namespace paddle {
namespace operators {
class CosineOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OperatorWithKernel {
public:
void Run(const OpRunContext *context) const override {
printf("%s\n", DebugString().c_str());
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace operators
} // namespace operators
...@@ -22,8 +22,8 @@ limitations under the License. */ ...@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/platform/dynload/curand.h" #include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include "paddle/platform/place.h" #include <paddle/platform/place.h>
#include "unsupported/Eigen/CXX11/Tensor" #include <unsupported/Eigen/CXX11/Tensor>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -31,9 +31,16 @@ namespace platform { ...@@ -31,9 +31,16 @@ namespace platform {
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
}; };
class CPUDeviceContext : public DeviceContext {}; class CPUDeviceContext : public DeviceContext {
public:
Place GetPlace() const override {
Place retv = CPUPlace();
return retv;
}
};
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
...@@ -61,6 +68,11 @@ class CUDADeviceContext : public DeviceContext { ...@@ -61,6 +68,11 @@ class CUDADeviceContext : public DeviceContext {
eigen_device_ = new Eigen::GpuDevice(eigen_stream_); eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
} }
Place GetPlace() const override {
Place retv = GPUPlace();
return retv;
}
void Wait() { void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed"); "cudaStreamSynchronize failed");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册