diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index e1bbcfc1a6e341d331463c5e88248475937a5f94..1493b577c5d6eabd9200d6f31d619172a513ff0e 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   CHECK(!op_type_.empty()) << "op_type_ should be set first";
 
   for (auto place : places) {
-    kernels.emplace_back(KernelRegistry::Global().Create(
+    auto ks = KernelRegistry::Global().Create(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
-        place.precision));
+        place.precision);
+    for (auto &&it : ks) {
+      kernels.emplace_back(std::move(it));
+    }
   }
 
   return kernels;
@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
   return AttachImpl(opdesc, scope);
 }
 
+const Tensor *OpLite::GetTensor(lite::Scope *scope,
+                                const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return &var->Get<lite::Tensor>();
+}
+
+Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
+                                 const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return var->GetMutable<lite::Tensor>();
+}
+
 bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
   for (auto &item : input_argument_) {
     auto it = std::find(item.second.begin(), item.second.end(), value_name);
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 3fd020b176a397e8a2bf02cdece68452ca1ce08f..54d973ebc7de363b136e10a043cc02739d4ecd79 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -119,6 +119,9 @@ class OpLite : public Registry {
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
       const std::vector<Place> &places, const std::string &kernel_type = "");
 
+  const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
+  Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
+
   friend class mir::Node;
   friend class mir::SSAGraph;
diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc
index 38bc79aaba66fd9b8166c2a32421542ae4af4744..676cbe2dfcf33d71c699689e69791f8e51bb7a5e 100644
--- a/paddle/fluid/lite/core/op_registry.cc
+++ b/paddle/fluid/lite/core/op_registry.cc
@@ -17,9 +17,8 @@
 namespace paddle {
 namespace lite {
 
-std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
-                                                   TargetType target,
-                                                   PrecisionType precision) {
+std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
+    const std::string &op_type, TargetType target, PrecisionType precision) {
 #define CREATE_KERNEL(target__)                          \
   switch (precision) {                                   \
     case PRECISION(kFloat):                              \
@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
   }
 #undef CREATE_KERNEL
 
-  return nullptr;
+  return std::list<std::unique_ptr<KernelBase>>();
 }
 
 KernelRegistry::KernelRegistry() {
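Note: `KernelRegistry::Create` now returns a (possibly empty) `std::list<std::unique_ptr<KernelBase>>` instead of a nullable pointer, so call sites iterate over the results rather than null-checking, exactly as `OpLite::CreateKernels` above now does. A minimal sketch of the new calling convention, reusing the `Global()` accessor seen above (the op name and place are illustrative, not part of this patch):

    auto ks = paddle::lite::KernelRegistry::Global().Create(
        "fc", TARGET(kHost), PRECISION(kFloat));
    for (auto &&k : ks) {
      // Each k is a std::unique_ptr<KernelBase>; an empty list replaces the
      // old nullptr return for an unknown (op, target, precision) triple.
    }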
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index 04c19fbf873a9433eba27456b2e3ed2b1e07ca8d..749f786a71c0c06d2135ee7d71a68124d54e1247 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
 
 template <TargetType Target, PrecisionType Precision>
 using KernelRegistryForTarget =
-    Factory<KernelLite<Target, Precision>,
-            std::unique_ptr<KernelLite<Target, Precision>>>;
+    Factory<KernelLite<Target, Precision>, std::unique_ptr<KernelBase>>;
 
 class KernelRegistry final {
  public:
@@ -80,16 +79,16 @@ class KernelRegistry final {
   }
 
   template <TargetType Target, PrecisionType Precision>
-  std::unique_ptr<KernelBase> Create(const std::string &op_type) {
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
     using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
     return registries_[GetKernelOffset<Target, Precision>()]
         .template get<kernel_registor_t *>()
-        ->Create(op_type);
+        ->Creates(op_type);
   }
 
-  std::unique_ptr<KernelBase> Create(const std::string &op_type,
-                                     TargetType target,
-                                     PrecisionType precision);
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
+                                                TargetType target,
+                                                PrecisionType precision);
 
   // Get a kernel registry offset in all the registries.
   template <TargetType Target, PrecisionType Precision>
@@ -151,29 +150,36 @@
 // Kernel registry
 #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
   op_type__##__##target__##__##precision__##__registor__
-#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
-  op_type__##__##target__##__##precision__##__registor__instance__
-#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
-
-#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
-  static paddle::lite::KernelRegistor<TARGET(target__),                     \
-                                      PRECISION(precision__), KernelClass>  \
-      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__,                    \
-                                    precision__)(#op_type__);               \
-  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
-  int touch_##op_type__##target__##precision__() {                          \
-    LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch();         \
-    return 0;                                                               \
-  }                                                                         \
-  static bool op_type__##target__##precision__##param_register              \
-      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
-          TARGET(target__), PRECISION(precision__)>(#op_type__)
-
-#define USE_LITE_KERNEL(op_type__, target__, precision__)         \
-  extern int touch_##op_type__##target__##precision__();          \
-  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-      __attribute__((unused)) = touch_##op_type__##target__##precision__();
-
-#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
-  op_type__##target__##precision__
+#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
+                                      alias__)                          \
+  op_type__##__##target__##__##precision__##__registor__instance__##alias__
+#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
+  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
+
+#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
+                             alias__)                                       \
+  static paddle::lite::KernelRegistor<TARGET(target__),                     \
+                                      PRECISION(precision__), KernelClass>  \
+      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__,       \
+                                    alias__)(#op_type__);                   \
+  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
+                                          alias__);                         \
+  int touch_##op_type__##target__##precision__##alias__() {                 \
+    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
+    return 0;                                                               \
+  }                                                                         \
+  static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__,  \
+                                         alias__) __attribute__((unused)) = \
+      paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__),        \
+                                                   PRECISION(precision__)>( \
+          #op_type__)
+
+#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__)        \
+  extern int touch_##op_type__##target__##precision__##alias__();         \
+  int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
+      touch_##op_type__##target__##precision__##alias__();
+
+#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__
+#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__##param_register
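Note: the new `alias__` parameter threads through every registration macro, so a single (op, target, precision) triple can now host several kernel implementations under distinct symbols instead of colliding on one name. The kernel files below all adopt the alias `def`; a sketch of the resulting pairing (`FcCompute` stands in for any kernel class):

    // In the kernel's translation unit:
    REGISTER_LITE_KERNEL(fc, kHost, kFloat,
                         paddle::lite::kernels::host::FcCompute, def)
        .Finalize();

    // In any binary that must link that kernel in:
    USE_LITE_KERNEL(fc, kHost, kFloat, def);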
diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc
index e81900bad9d6d5b32bb4331211c5aa7a11dfe89d..ac9a4ccc0a7bd054603e39d8ca4cbe029837554f 100644
--- a/paddle/fluid/lite/kernels/host/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/host/fc_compute.cc
@@ -24,7 +24,7 @@ namespace host {
 
 // NOTE should use pure std C++ implementation.
 void FcCompute::Run() {
-  auto& param = this->param<operators::FcParam>();
+  auto& param = this->Param<operators::FcParam>();
 
   CHECK_GE(param.input->dims().size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);
@@ -51,7 +51,8 @@ void FcCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
+REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
+                     def)
     .BindInput("Input",
                {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                    TARGET(kHost))})
diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc
index 6a0f480a4d1c6592478abcb0e0bd7c01b31cd772..342d5d55573350b7b447da2ca017415b83b3fe7f 100644
--- a/paddle/fluid/lite/kernels/host/feed_compute.cc
+++ b/paddle/fluid/lite/kernels/host/feed_compute.cc
@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::FeedParam;
 
   void Run() override {
-    auto &theparam = param<param_t>();
+    auto &theparam = Param<param_t>();
     const Tensor &feed_item = theparam.feed_list->at(theparam.col);
     theparam.out->CopyDataFrom(feed_item);
   }
@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(feed, kHost, kFloat,
-                     paddle::lite::kernels::host::FeedCompute)
+                     paddle::lite::kernels::host::FeedCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
                         TARGET(kHost))})
     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc
index a9667dd8312b4436efa65bedca96a4a8d2e86643..b2062fa9308a1db241ea1dfefbacaa78fcdb7d30 100644
--- a/paddle/fluid/lite/kernels/host/mul_compute.cc
+++ b/paddle/fluid/lite/kernels/host/mul_compute.cc
@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<param_t>();
+    auto& theparam = Param<param_t>();
     core::dim2 x_shape(
         {product(theparam.x->dims().begin(),
                  theparam.x->dims().begin() + theparam.x_num_col_dims),
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
-                     paddle::lite::kernels::host::MulCompute)
+                     paddle::lite::kernels::host::MulCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/host/relu_compute.h
index 4ff3b8ca64cd80aa135920135b599f7641167aaf..b8176377dcc2c729a816a40ea2eb0769196425b7 100644
--- a/paddle/fluid/lite/kernels/host/relu_compute.h
+++ b/paddle/fluid/lite/kernels/host/relu_compute.h
@@ -24,7 +24,7 @@ namespace host {
 class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
  public:
   void Run() override {
-    auto& theparam = param<operators::ReluParam>();
+    auto& theparam = Param<operators::ReluParam>();
     auto n = product(theparam.input->dims());
     const float* input = theparam.input->data<float>();
     float* output = theparam.output->mutable_data<float>();
@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(relu, kHost, kFloat,
-                     paddle::lite::kernels::host::ReluCompute)
+                     paddle::lite::kernels::host::ReluCompute, def)
     .Finalize();
diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc
index 38211a2c545dc954c4c4a6edc4dd0b8ab7634efd..490792be6aa84272bdec264ee4e252dff67f4280 100644
--- a/paddle/fluid/lite/kernels/host/scale_compute.cc
+++ b/paddle/fluid/lite/kernels/host/scale_compute.cc
@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<operators::ScaleParam>();
+    auto& theparam = Param<operators::ScaleParam>();
     scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
                   product(theparam.x->dims()), theparam.scale, theparam.bias,
                   theparam.bias_after_scale);
@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(scale, kHost, kFloat,
-                     paddle::lite::kernels::host::ScaleCompute)
+                     paddle::lite::kernels::host::ScaleCompute, def)
     .Finalize();
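Note: the kernel-side edits are mechanical: every host kernel gains the `def` alias in its registration, and the param accessor is now spelled `Param<T>()` rather than `param<T>()`. For a new host kernel the pattern would look like this (hypothetical `FooCompute`, mirroring the files above; the param struct is illustrative):

    class FooCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
     public:
      using param_t = operators::FcParam;  // illustrative param struct
      void Run() override {
        auto &p = Param<param_t>();  // renamed from param<param_t>()
        // ... compute with p ...
      }
    };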
diff --git a/paddle/fluid/lite/operators/fc_op.cc b/paddle/fluid/lite/operators/fc_op.cc
index 4d3d5506b10444fead07ccfc718a7d5d1b8c3a54..e4f6d336307b393d4cd830e5429dff4580f82864 100644
--- a/paddle/fluid/lite/operators/fc_op.cc
+++ b/paddle/fluid/lite/operators/fc_op.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "fc_op.h"
+#include "paddle/fluid/lite/operators/fc_op.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
 namespace paddle {
diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h
index 395390b3b5b3a9b173914a114b81c377a7b4c0cf..cc00e42651b601422ad5d76035b11cb3da94771d 100644
--- a/paddle/fluid/lite/utils/factory.h
+++ b/paddle/fluid/lite/utils/factory.h
@@ -14,6 +14,7 @@
 
 #pragma once
 #include <glog/logging.h>
+#include <list>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -49,13 +50,21 @@ class Factory {
   void Register(const std::string& op_type, creator_t&& creator) {
     CHECK(!creators_.count(op_type)) << "The op " << op_type
                                      << " has already registered";
-    creators_.emplace(op_type, std::move(creator));
+    creators_[op_type].emplace_back(std::move(creator));
   }
 
   item_ptr_t Create(const std::string& op_type) const {
+    return std::move(Creates(op_type).front());
+  }
+
+  std::list<item_ptr_t> Creates(const std::string& op_type) const {
     auto it = creators_.find(op_type);
     CHECK(it != creators_.end()) << "no item called " << op_type;
-    return it->second();
+    std::list<item_ptr_t> res;
+    for (auto& c : it->second) {
+      res.emplace_back(c());
+    }
+    return res;
   }
 
   std::string DebugString() const {
@@ -67,7 +76,7 @@ class Factory {
   }
 
  protected:
-  std::unordered_map<std::string, creator_t> creators_;
+  std::unordered_map<std::string, std::list<creator_t>> creators_;
 };
 
 /*  A helper function to help run a lambda at the start.
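Note: `Factory` now keeps a `std::list` of creators per key, and `Creates` instantiates one fresh item per registered creator; `Create` keeps its old single-item signature by returning the front of that list. One caveat: `Register` still CHECK-fails when the key already exists, so as of this patch each `op_type` can in practice hold only one creator; the list plumbing is what later permits several kernels to share a name. A self-contained sketch of the new storage layout (hypothetical `Item` type, not from the patch):

    #include <functional>
    #include <list>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Item { int id; };
    using creator_t = std::function<std::unique_ptr<Item>()>;

    int main() {
      // Per-key creator lists, as in Factory::creators_ after this patch.
      std::unordered_map<std::string, std::list<creator_t>> creators;
      creators["fc"].emplace_back(
          [] { return std::unique_ptr<Item>(new Item{0}); });

      // Creates(): run every creator registered for the key.
      std::list<std::unique_ptr<Item>> res;
      for (auto& c : creators["fc"]) res.emplace_back(c());
      return res.size() == 1 ? 0 : 1;
    }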