提交 25990d29 编写于 作者: S Superjomn

make kernel_registry support multiple kernels for single type

上级 e55a5cd9
...@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
CHECK(!op_type_.empty()) << "op_type_ should be set first"; CHECK(!op_type_.empty()) << "op_type_ should be set first";
for (auto place : places) { for (auto place : places) {
kernels.emplace_back(KernelRegistry::Global().Create( auto ks = KernelRegistry::Global().Create(
(kernel_type.empty() ? op_type_ : kernel_type), place.target, (kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision)); place.precision);
for (auto &&it : ks) {
kernels.emplace_back(std::move(it));
}
} }
return kernels; return kernels;
...@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { ...@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
return AttachImpl(opdesc, scope); return AttachImpl(opdesc, scope);
} }
const Tensor *OpLite::GetTensor(lite::Scope *scope,
const std::string &name) const {
auto *var = scope->FindVar(name);
CHECK(var) << "no variable called " << name << " found";
return &var->Get<lite::Tensor>();
}
Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
const std::string &name) const {
auto *var = scope->FindVar(name);
CHECK(var) << "no variable called " << name << " found";
return var->GetMutable<lite::Tensor>();
}
bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) { bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
for (auto &item : input_argument_) { for (auto &item : input_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name); auto it = std::find(item.second.begin(), item.second.end(), value_name);
......
...@@ -119,6 +119,9 @@ class OpLite : public Registry { ...@@ -119,6 +119,9 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = ""); const std::vector<Place> &places, const std::string &kernel_type = "");
const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
friend class mir::Node; friend class mir::Node;
friend class mir::SSAGraph; friend class mir::SSAGraph;
......
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type, std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
TargetType target, const std::string &op_type, TargetType target, PrecisionType precision) {
PrecisionType precision) {
#define CREATE_KERNEL(target__) \ #define CREATE_KERNEL(target__) \
switch (precision) { \ switch (precision) { \
case PRECISION(kFloat): \ case PRECISION(kFloat): \
...@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type, ...@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
} }
#undef CREATE_KERNEL #undef CREATE_KERNEL
return nullptr; return std::list<std::unique_ptr<KernelBase>>();
} }
KernelRegistry::KernelRegistry() { KernelRegistry::KernelRegistry() {
......
...@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> { ...@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision>
using KernelRegistryForTarget = using KernelRegistryForTarget =
Factory<OpKernel<Target, Precision>, Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>;
std::unique_ptr<OpKernel<Target, Precision>>>;
class KernelRegistry final { class KernelRegistry final {
public: public:
...@@ -80,16 +79,16 @@ class KernelRegistry final { ...@@ -80,16 +79,16 @@ class KernelRegistry final {
} }
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision>
std::unique_ptr<KernelBase> Create(const std::string &op_type) { std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>; using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
return registries_[GetKernelOffset<Target, Precision>()] return registries_[GetKernelOffset<Target, Precision>()]
.template get<kernel_registor_t *>() .template get<kernel_registor_t *>()
->Create(op_type); ->Creates(op_type);
} }
std::unique_ptr<KernelBase> Create(const std::string &op_type, std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
TargetType target, TargetType target,
PrecisionType precision); PrecisionType precision);
// Get a kernel registry offset in all the registries. // Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision>
...@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
// Kernel registry // Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__ op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \ #define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
op_type__##__##target__##__##precision__##__registor__instance__ alias__) \
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \ op_type__##__##target__##__##precision__##__registor__instance__##alias__
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
static paddle::lite::KernelRegistor<TARGET(target__), \ #define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
PRECISION(precision__), KernelClass> \ alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \ static paddle::lite::KernelRegistor<TARGET(target__), \
precision__)(#op_type__); \ PRECISION(precision__), KernelClass> \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
int touch_##op_type__##target__##precision__() { \ alias__)(#op_type__); \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \ static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
return 0; \ alias__); \
} \ int touch_##op_type__##target__##precision__##alias__() { \
static bool op_type__##target__##precision__##param_register \ LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \ return 0; \
TARGET(target__), PRECISION(precision__)>(#op_type__) } \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
#define USE_LITE_KERNEL(op_type__, target__, precision__) \ alias__) __attribute__((unused)) = \
extern int touch_##op_type__##target__##precision__(); \ paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \ PRECISION(precision__)>( \
__attribute__((unused)) = touch_##op_type__##target__##precision__(); #op_type__)
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \ #define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__ extern int touch_##op_type__##target__##precision__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__##param_register
...@@ -24,7 +24,7 @@ namespace host { ...@@ -24,7 +24,7 @@ namespace host {
// NOTE should use pure std C++ implementation. // NOTE should use pure std C++ implementation.
void FcCompute::Run() { void FcCompute::Run() {
auto& param = this->param<operators::FcParam>(); auto& param = this->Param<operators::FcParam>();
CHECK_GE(param.input->dims().size(), 2UL); CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL); CHECK_EQ(param.output->dims().size(), 2UL);
...@@ -51,7 +51,8 @@ void FcCompute::Run() { ...@@ -51,7 +51,8 @@ void FcCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute) REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
def)
.BindInput("Input", .BindInput("Input",
{paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::FeedParam; using param_t = operators::FeedParam;
void Run() override { void Run() override {
auto &theparam = param<operators::FeedParam>(); auto &theparam = Param<operators::FeedParam>();
const Tensor &feed_item = theparam.feed_list->at(theparam.col); const Tensor &feed_item = theparam.feed_list->at(theparam.col);
theparam.out->CopyDataFrom(feed_item); theparam.out->CopyDataFrom(feed_item);
} }
...@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(feed, kHost, kFloat, REGISTER_LITE_KERNEL(feed, kHost, kFloat,
paddle::lite::kernels::host::FeedCompute) paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
......
...@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::MulParam; using param_t = operators::MulParam;
void Run() override { void Run() override {
auto& theparam = param<operators::MulParam>(); auto& theparam = Param<operators::MulParam>();
core::dim2 x_shape( core::dim2 x_shape(
{product(theparam.x->dims().begin(), {product(theparam.x->dims().begin(),
theparam.x->dims().begin() + theparam.x_num_col_dims), theparam.x->dims().begin() + theparam.x_num_col_dims),
...@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(mul, kHost, kFloat, REGISTER_LITE_KERNEL(mul, kHost, kFloat,
paddle::lite::kernels::host::MulCompute) paddle::lite::kernels::host::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
.BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
......
...@@ -24,7 +24,7 @@ namespace host { ...@@ -24,7 +24,7 @@ namespace host {
class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public: public:
void Run() override { void Run() override {
auto& theparam = param<operators::ReluParam>(); auto& theparam = Param<operators::ReluParam>();
auto n = product(theparam.input->dims()); auto n = product(theparam.input->dims());
const float* input = theparam.input->data<float>(); const float* input = theparam.input->data<float>();
float* output = theparam.output->mutable_data<float>(); float* output = theparam.output->mutable_data<float>();
...@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(relu, kHost, kFloat, REGISTER_LITE_KERNEL(relu, kHost, kFloat,
paddle::lite::kernels::host::ReluCompute) paddle::lite::kernels::host::ReluCompute, def)
.Finalize(); .Finalize();
...@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using param_t = operators::MulParam; using param_t = operators::MulParam;
void Run() override { void Run() override {
auto& theparam = param<operators::ScaleParam>(); auto& theparam = Param<operators::ScaleParam>();
scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(), scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
product(theparam.x->dims()), theparam.scale, theparam.bias, product(theparam.x->dims()), theparam.scale, theparam.bias,
theparam.bias_after_scale); theparam.bias_after_scale);
...@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(scale, kHost, kFloat, REGISTER_LITE_KERNEL(scale, kHost, kFloat,
paddle::lite::kernels::host::ScaleCompute) paddle::lite::kernels::host::ScaleCompute, def)
.Finalize(); .Finalize();
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "fc_op.h" #include "paddle/fluid/lite/operators/fc_op.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
namespace paddle { namespace paddle {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <list>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
...@@ -49,13 +50,21 @@ class Factory { ...@@ -49,13 +50,21 @@ class Factory {
void Register(const std::string& op_type, creator_t&& creator) { void Register(const std::string& op_type, creator_t&& creator) {
CHECK(!creators_.count(op_type)) << "The op " << op_type CHECK(!creators_.count(op_type)) << "The op " << op_type
<< " has already registered"; << " has already registered";
creators_.emplace(op_type, std::move(creator)); creators_[op_type].emplace_back(std::move(creator));
} }
item_ptr_t Create(const std::string& op_type) const { item_ptr_t Create(const std::string& op_type) const {
return std::move(Creates(op_type).front());
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
auto it = creators_.find(op_type); auto it = creators_.find(op_type);
CHECK(it != creators_.end()) << "no item called " << op_type; CHECK(it != creators_.end()) << "no item called " << op_type;
return it->second(); std::list<item_ptr_t> res;
for (auto& c : it->second) {
res.emplace_back(c());
}
return res;
} }
std::string DebugString() const { std::string DebugString() const {
...@@ -67,7 +76,7 @@ class Factory { ...@@ -67,7 +76,7 @@ class Factory {
} }
protected: protected:
std::unordered_map<std::string, creator_t> creators_; std::unordered_map<std::string, std::list<creator_t>> creators_;
}; };
/* A helper function to help run a lambda at the start. /* A helper function to help run a lambda at the start.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册