From 25990d296c52b26a868d25863cecb33e67e45484 Mon Sep 17 00:00:00 2001
From: Superjomn
Date: Sun, 21 Apr 2019 16:39:57 +0800
Subject: [PATCH] make kernel_registry support multiple kernels for a single
 type

---
 paddle/fluid/lite/core/op_lite.cc             | 21 +++++-
 paddle/fluid/lite/core/op_lite.h              |  3 +
 paddle/fluid/lite/core/op_registry.cc         |  7 +-
 paddle/fluid/lite/core/op_registry.h          | 72 ++++++++++---------
 paddle/fluid/lite/kernels/host/fc_compute.cc  |  5 +-
 .../fluid/lite/kernels/host/feed_compute.cc   |  4 +-
 paddle/fluid/lite/kernels/host/mul_compute.cc |  4 +-
 paddle/fluid/lite/kernels/host/relu_compute.h |  4 +-
 .../fluid/lite/kernels/host/scale_compute.cc  |  4 +-
 paddle/fluid/lite/operators/fc_op.cc          |  2 +-
 paddle/fluid/lite/utils/factory.h             | 15 +++-
 11 files changed, 88 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index e1bbcfc1a6e..1493b577c5d 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   CHECK(!op_type_.empty()) << "op_type_ should be set first";
 
   for (auto place : places) {
-    kernels.emplace_back(KernelRegistry::Global().Create(
+    auto ks = KernelRegistry::Global().Create(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
-        place.precision));
+        place.precision);
+    for (auto &&it : ks) {
+      kernels.emplace_back(std::move(it));
+    }
   }
 
   return kernels;
@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
   return AttachImpl(opdesc, scope);
 }
 
+const Tensor *OpLite::GetTensor(lite::Scope *scope,
+                                const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return &var->Get<Tensor>();
+}
+
+Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
+                                 const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return var->GetMutable<Tensor>();
+}
+
 bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
   for (auto &item : input_argument_) {
     auto it = std::find(item.second.begin(), item.second.end(), value_name);
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 3fd020b176a..54d973ebc7d 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -119,6 +119,9 @@ class OpLite : public Registry {
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
       const std::vector<Place> &places, const std::string &kernel_type = "");
 
+  const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
+  Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
+
   friend class mir::Node;
   friend class mir::SSAGraph;
 
diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc
index 38bc79aaba6..676cbe2dfcf 100644
--- a/paddle/fluid/lite/core/op_registry.cc
+++ b/paddle/fluid/lite/core/op_registry.cc
@@ -17,9 +17,8 @@
 namespace paddle {
 namespace lite {
 
-std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
-                                                   TargetType target,
-                                                   PrecisionType precision) {
+std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
+    const std::string &op_type, TargetType target, PrecisionType precision) {
 #define CREATE_KERNEL(target__) \
   switch (precision) { \
     case PRECISION(kFloat): \
@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
   }
 #undef CREATE_KERNEL
 
-  return nullptr;
+  return std::list<std::unique_ptr<KernelBase>>();
 }
 
 KernelRegistry::KernelRegistry() {
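NOTE (illustrative sketch, not part of the patch): the two files above change the lookup contract. KernelRegistry::Create() now returns a std::list of candidate kernels for one (op_type, target, precision) query instead of a single kernel, and OpLite::CreateKernels() flattens those per-place lists into one vector. The self-contained sketch below restates that flow; MiniRegistry, DummyKernel, and the simplified Place are invented stand-ins, not the real lite types.

#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct DummyKernel {
  std::string name;
};

struct Place {
  std::string target;  // e.g. "kHost"
};

class MiniRegistry {
 public:
  // One (op_type, target) query may now yield several kernel implementations.
  std::list<std::unique_ptr<DummyKernel>> Create(const std::string &op_type,
                                                 const std::string &target) {
    std::list<std::unique_ptr<DummyKernel>> res;
    res.emplace_back(new DummyKernel{op_type + "/" + target + "/impl0"});
    res.emplace_back(new DummyKernel{op_type + "/" + target + "/impl1"});
    return res;
  }
};

// Mirrors the patched OpLite::CreateKernels(): flatten every place's
// candidate list into one vector of kernels.
std::vector<std::unique_ptr<DummyKernel>> CreateKernels(
    MiniRegistry &registry, const std::string &op_type,
    const std::vector<Place> &places) {
  std::vector<std::unique_ptr<DummyKernel>> kernels;
  for (const auto &place : places) {
    auto ks = registry.Create(op_type, place.target);
    for (auto &&it : ks) kernels.emplace_back(std::move(it));
  }
  return kernels;
}

int main() {
  MiniRegistry registry;
  auto kernels = CreateKernels(registry, "fc", {{"kHost"}, {"kARM"}});
  for (auto &k : kernels) std::cout << k->name << "\n";  // 2 places x 2 impls
}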
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index 04c19fbf873..749f786a71c 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
 
 template <TargetType Target, PrecisionType Precision>
 using KernelRegistryForTarget =
-    Factory<OpKernel<Target, Precision>,
-            std::unique_ptr<OpKernel<Target, Precision>>>;
+    Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>;
 
 class KernelRegistry final {
  public:
@@ -80,16 +79,16 @@ class KernelRegistry final {
   }
 
   template <TargetType Target, PrecisionType Precision>
-  std::unique_ptr<KernelBase> Create(const std::string &op_type) {
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
     using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
     return registries_[GetKernelOffset<Target, Precision>()]
         .template get<kernel_registor_t *>()
-        ->Create(op_type);
+        ->Creates(op_type);
   }
 
-  std::unique_ptr<KernelBase> Create(const std::string &op_type,
-                                     TargetType target,
-                                     PrecisionType precision);
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
+                                                TargetType target,
+                                                PrecisionType precision);
 
   // Get a kernel registry offset in all the registries.
   template <TargetType Target, PrecisionType Precision>
@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
 // Kernel registry
 #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
   op_type__##__##target__##__##precision__##__registor__
-#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
-  op_type__##__##target__##__##precision__##__registor__instance__
-#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
-
-#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
-  static paddle::lite::KernelRegistor<TARGET(target__), \
-                                      PRECISION(precision__), KernelClass> \
-      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \
-                                    precision__)(#op_type__); \
-  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
-  int touch_##op_type__##target__##precision__() { \
-    LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \
-    return 0; \
-  } \
-  static bool op_type__##target__##precision__##param_register \
-      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
-          TARGET(target__), PRECISION(precision__)>(#op_type__)
-
-#define USE_LITE_KERNEL(op_type__, target__, precision__) \
-  extern int touch_##op_type__##target__##precision__(); \
-  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-      __attribute__((unused)) = touch_##op_type__##target__##precision__();
-
-#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
-  op_type__##target__##precision__
+#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
+                                      alias__) \
+  op_type__##__##target__##__##precision__##__registor__instance__##alias__
+#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
+  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
+
+#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
+                             alias__) \
+  static paddle::lite::KernelRegistor<TARGET(target__), \
+                                      PRECISION(precision__), KernelClass> \
+      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
+                                    alias__)(#op_type__); \
+  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
+                                          alias__); \
+  int touch_##op_type__##target__##precision__##alias__() { \
+    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
+    return 0; \
+  } \
+  static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
+                                         alias__) __attribute__((unused)) = \
+      paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
+                                                   PRECISION(precision__)>( \
+          #op_type__)
+
+#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
+  extern int touch_##op_type__##target__##precision__##alias__(); \
+  int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
+      touch_##op_type__##target__##precision__##alias__();
+
+#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__
+#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__##param_register
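NOTE (toy illustration, not part of the patch): the key mechanism in the macro rewrite above is that the new alias__ argument is token-pasted into every generated symbol (the registor instance, the kernel instance, the touch function, and the param-register flag), so two kernels registered for the same (op, target, precision) triple no longer collide at link time. A compilable toy version of that idea, with invented names (TOY_REGISTER_KERNEL, Registry):

#include <iostream>
#include <string>
#include <vector>

// Toy registry used only for this illustration.
std::vector<std::string>& Registry() {
  static std::vector<std::string> r;
  return r;
}

// Mimics the patched REGISTER_LITE_KERNEL: alias__ is pasted into the
// generated variable name, so repeated registrations of the same
// op/target/precision produce distinct symbols.
#define TOY_REGISTER_KERNEL(op__, target__, precision__, alias__)            \
  static bool op__##_##target__##_##precision__##_##alias__ = [] {           \
    Registry().push_back(#op__ "/" #target__ "/" #precision__ "/" #alias__); \
    return true;                                                             \
  }()

TOY_REGISTER_KERNEL(fc, kHost, kFloat, def);
TOY_REGISTER_KERNEL(fc, kHost, kFloat, eigen);  // distinct symbol via alias

int main() {
  for (const auto& s : Registry()) std::cout << s << "\n";
}

Without the alias__ token both statics would be named fc_kHost_kFloat and the second registration would be a redefinition error.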
diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc
index e81900bad9d..ac9a4ccc0a7 100644
--- a/paddle/fluid/lite/kernels/host/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/host/fc_compute.cc
@@ -24,7 +24,7 @@ namespace host {
 
 // NOTE should use pure std C++ implementation.
 void FcCompute::Run() {
-  auto& param = this->param<operators::FcParam>();
+  auto& param = this->Param<operators::FcParam>();
 
   CHECK_GE(param.input->dims().size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);
@@ -51,7 +51,8 @@ void FcCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
+REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
+                     def)
     .BindInput("Input",
                {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                    TARGET(kHost))})
diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc
index 6a0f480a4d1..342d5d55573 100644
--- a/paddle/fluid/lite/kernels/host/feed_compute.cc
+++ b/paddle/fluid/lite/kernels/host/feed_compute.cc
@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::FeedParam;
 
   void Run() override {
-    auto &theparam = param<operators::FeedParam>();
+    auto &theparam = Param<operators::FeedParam>();
     const Tensor &feed_item = theparam.feed_list->at(theparam.col);
     theparam.out->CopyDataFrom(feed_item);
   }
@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(feed, kHost, kFloat,
-                     paddle::lite::kernels::host::FeedCompute)
+                     paddle::lite::kernels::host::FeedCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
                         TARGET(kHost))})
     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc
index a9667dd8312..b2062fa9308 100644
--- a/paddle/fluid/lite/kernels/host/mul_compute.cc
+++ b/paddle/fluid/lite/kernels/host/mul_compute.cc
@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<operators::MulParam>();
+    auto& theparam = Param<operators::MulParam>();
     core::dim2 x_shape(
         {product(theparam.x->dims().begin(),
                  theparam.x->dims().begin() + theparam.x_num_col_dims),
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
-                     paddle::lite::kernels::host::MulCompute)
+                     paddle::lite::kernels::host::MulCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/host/relu_compute.h
index 4ff3b8ca64c..b8176377dcc 100644
--- a/paddle/fluid/lite/kernels/host/relu_compute.h
+++ b/paddle/fluid/lite/kernels/host/relu_compute.h
@@ -24,7 +24,7 @@ namespace host {
 class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
  public:
   void Run() override {
-    auto& theparam = param<operators::ReluParam>();
+    auto& theparam = Param<operators::ReluParam>();
     auto n = product(theparam.input->dims());
     const float* input = theparam.input->data<float>();
     float* output = theparam.output->mutable_data<float>();
@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(relu, kHost, kFloat,
-                     paddle::lite::kernels::host::ReluCompute)
+                     paddle::lite::kernels::host::ReluCompute, def)
     .Finalize();
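NOTE (illustrative sketch, not part of the patch): the mechanical param() -> Param() rename in the kernels above targets the kernel's typed parameter accessor. The sketch below shows the general pattern, a type-erased pointer stored once and fetched back as a concrete struct; ToyKernelBase, FcParam, and SetParam are invented for illustration and are not the lite OpKernel API.

#include <cassert>
#include <iostream>

struct FcParam {
  int in_num_col_dims{1};
};

class ToyKernelBase {
 public:
  // The framework attaches an op-specific param struct once...
  template <typename P>
  void SetParam(P* p) {
    param_ = p;
  }
  // ...and the kernel's Run() fetches it back, typed, via Param<P>().
  template <typename P>
  P& Param() {
    assert(param_ != nullptr);
    return *static_cast<P*>(param_);
  }

 private:
  void* param_{nullptr};  // type-erased; the caller must ask for the right P
};

class FcCompute : public ToyKernelBase {
 public:
  void Run() {
    auto& param = this->Param<FcParam>();  // mirrors the patched call sites
    std::cout << "in_num_col_dims = " << param.in_num_col_dims << "\n";
  }
};

int main() {
  FcParam p;
  FcCompute fc;
  fc.SetParam(&p);
  fc.Run();
}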
diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc
index 38211a2c545..490792be6aa 100644
--- a/paddle/fluid/lite/kernels/host/scale_compute.cc
+++ b/paddle/fluid/lite/kernels/host/scale_compute.cc
@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<operators::ScaleParam>();
+    auto& theparam = Param<operators::ScaleParam>();
     scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
                   product(theparam.x->dims()), theparam.scale, theparam.bias,
                   theparam.bias_after_scale);
@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(scale, kHost, kFloat,
-                     paddle::lite::kernels::host::ScaleCompute)
+                     paddle::lite::kernels::host::ScaleCompute, def)
     .Finalize();
diff --git a/paddle/fluid/lite/operators/fc_op.cc b/paddle/fluid/lite/operators/fc_op.cc
index 4d3d5506b10..e4f6d336307 100644
--- a/paddle/fluid/lite/operators/fc_op.cc
+++ b/paddle/fluid/lite/operators/fc_op.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "fc_op.h"
+#include "paddle/fluid/lite/operators/fc_op.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
 namespace paddle {
diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h
index 395390b3b5b..cc00e42651b 100644
--- a/paddle/fluid/lite/utils/factory.h
+++ b/paddle/fluid/lite/utils/factory.h
@@ -14,6 +14,7 @@
 #pragma once
 #include <iostream>
+#include <list>
 #include <memory>
 #include <sstream>
 #include <unordered_map>
@@ -49,13 +50,21 @@ class Factory {
   void Register(const std::string& op_type, creator_t&& creator) {
     CHECK(!creators_.count(op_type)) << "The op " << op_type
                                      << " has already registered";
-    creators_.emplace(op_type, std::move(creator));
+    creators_[op_type].emplace_back(std::move(creator));
   }
 
   item_ptr_t Create(const std::string& op_type) const {
+    return std::move(Creates(op_type).front());
+  }
+
+  std::list<item_ptr_t> Creates(const std::string& op_type) const {
     auto it = creators_.find(op_type);
     CHECK(it != creators_.end()) << "no item called " << op_type;
-    return it->second();
+    std::list<item_ptr_t> res;
+    for (auto& c : it->second) {
+      res.emplace_back(c());
+    }
+    return res;
   }
 
   std::string DebugString() const {
@@ -67,7 +76,7 @@ class Factory {
   }
 
  protected:
-  std::unordered_map<std::string, creator_t> creators_;
+  std::unordered_map<std::string, std::list<creator_t>> creators_;
 };
 
 /* A helper function to help run a lambda at the start.
--
GitLab
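NOTE (condensed restatement, not part of the patch): factory.h carries the heart of the change. Each key now maps to a list of creators, Creates() instantiates all of them, and Create() preserves the old single-item contract by returning the front of the list. A minimal runnable version; MiniFactory is a stand-in, and the real Factory CHECK-fails on an unknown key instead of returning an empty list.

#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

template <typename ItemType>
class MiniFactory {
 public:
  using item_ptr_t = std::unique_ptr<ItemType>;
  using creator_t = std::function<item_ptr_t()>;

  // Registering the same key twice now appends instead of failing.
  void Register(const std::string& op_type, creator_t&& creator) {
    creators_[op_type].emplace_back(std::move(creator));
  }

  // Instantiate every creator registered under op_type.
  std::list<item_ptr_t> Creates(const std::string& op_type) const {
    std::list<item_ptr_t> res;
    auto it = creators_.find(op_type);
    if (it == creators_.end()) return res;  // real code CHECK-fails here
    for (auto& c : it->second) res.emplace_back(c());
    return res;
  }

  // Old single-item contract: hand back the first registered kernel.
  item_ptr_t Create(const std::string& op_type) const {
    return std::move(Creates(op_type).front());
  }

 private:
  std::unordered_map<std::string, std::list<creator_t>> creators_;
};

int main() {
  MiniFactory<std::string> factory;
  factory.Register("fc", [] {
    return std::unique_ptr<std::string>(new std::string("host_fc"));
  });
  factory.Register("fc", [] {
    return std::unique_ptr<std::string>(new std::string("arm_fc"));
  });
  for (auto& k : factory.Creates("fc")) std::cout << *k << "\n";  // both
  std::cout << *factory.Create("fc") << "\n";  // just the first: host_fc
}

Keeping Create() as a front-of-list shim is what lets the registry change its return type without updating every caller in the same patch.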