Commit 25990d29 authored by S Superjomn

make kernel_registry support multiple kernels for single type

Parent e55a5cd9
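This commit makes the kernel registry map each (op_type, target, precision) key to a list of kernel creators instead of a single one, and threads a new alias__ token through the registration macros so that the generated symbols for different kernels of the same key no longer collide. As a sketch of the user-visible difference (the `def` alias is the one every call site in this commit adopts; the bare `.Finalize()` follows the relu/scale kernels below):

// Before: one kernel per (op, target, precision) key.
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
    .Finalize();

// After: the trailing alias token is pasted into the generated symbol names,
// so kernels registered under different aliases coexist for the same key.
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
                     def)
    .Finalize();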
@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   CHECK(!op_type_.empty()) << "op_type_ should be set first";
   for (auto place : places) {
-    kernels.emplace_back(KernelRegistry::Global().Create(
+    auto ks = KernelRegistry::Global().Create(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
-        place.precision));
+        place.precision);
+    for (auto &&it : ks) {
+      kernels.emplace_back(std::move(it));
+    }
   }
   return kernels;
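Callers of CreateKernels are unchanged syntactically, but the returned vector can now carry several candidate kernels per place. A rough caller-side sketch (the brace-initialized Place and the selection policy are assumptions for illustration, not code from this commit):

// Hypothetical caller: gather every registered candidate for host/float,
// then let the caller pick among them.
std::vector<Place> places;
places.push_back(Place{TARGET(kHost), PRECISION(kFloat)});
auto kernels = op->CreateKernels(places);  // op is an OpLite*
CHECK(!kernels.empty()) << "no kernel candidates found";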
@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
   return AttachImpl(opdesc, scope);
 }
 
+const Tensor *OpLite::GetTensor(lite::Scope *scope,
+                                const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return &var->Get<lite::Tensor>();
+}
+
+Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
+                                 const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return var->GetMutable<lite::Tensor>();
+}
+
 bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
   for (auto &item : input_argument_) {
     auto it = std::find(item.second.begin(), item.second.end(), value_name);
......
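The two new accessors fold the Scope lookup and the CHECK-on-miss into one place. A minimal usage sketch inside an OpLite subclass (the variable names "X" and "Out" are placeholders):

// Read an input and write an output by variable name; both helpers
// CHECK-fail if the named variable does not exist in the scope.
const Tensor *x = GetTensor(scope, "X");
Tensor *out = GetMutableTensor(scope, "Out");
out->CopyDataFrom(*x);  // CopyDataFrom as used by FeedCompute below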
@@ -119,6 +119,9 @@ class OpLite : public Registry {
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
       const std::vector<Place> &places, const std::string &kernel_type = "");
 
+  const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
+  Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
+
   friend class mir::Node;
   friend class mir::SSAGraph;
......
@@ -17,9 +17,8 @@
 namespace paddle {
 namespace lite {
 
-std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
-                                                   TargetType target,
-                                                   PrecisionType precision) {
+std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
+    const std::string &op_type, TargetType target, PrecisionType precision) {
 #define CREATE_KERNEL(target__) \
   switch (precision) {          \
     case PRECISION(kFloat):     \
@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
   }
 #undef CREATE_KERNEL
-  return nullptr;
+  return std::list<std::unique_ptr<KernelBase>>();
 }
 
 KernelRegistry::KernelRegistry() {
......
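Since Create now returns a container, the old nullptr sentinel for "no kernel registered" becomes an empty list, and callers that null-checked should test emptiness instead. A short sketch (assuming glog-style LOG, consistent with the CHECK macros used throughout):

// The not-found case is now an empty list rather than a null pointer.
auto ks = KernelRegistry::Global().Create("fc", TARGET(kHost),
                                          PRECISION(kFloat));
if (ks.empty()) {
  LOG(ERROR) << "no kernel registered for fc on host/float";
}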
@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
 template <TargetType Target, PrecisionType Precision>
 using KernelRegistryForTarget =
-    Factory<OpKernel<Target, Precision>,
-            std::unique_ptr<OpKernel<Target, Precision>>>;
+    Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>;
 
 class KernelRegistry final {
  public:
@@ -80,16 +79,16 @@ class KernelRegistry final {
   }
 
   template <TargetType Target, PrecisionType Precision>
-  std::unique_ptr<KernelBase> Create(const std::string &op_type) {
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
     using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
     return registries_[GetKernelOffset<Target, Precision>()]
         .template get<kernel_registor_t *>()
-        ->Create(op_type);
+        ->Creates(op_type);
   }
 
-  std::unique_ptr<KernelBase> Create(const std::string &op_type,
-                                     TargetType target,
-                                     PrecisionType precision);
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
+                                                TargetType target,
+                                                PrecisionType precision);
 
   // Get a kernel registry offset in all the registries.
   template <TargetType Target, PrecisionType Precision>
@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
 // Kernel registry
 #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
   op_type__##__##target__##__##precision__##__registor__
-#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
-  op_type__##__##target__##__##precision__##__registor__instance__
-#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
-
-#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass)  \
-  static paddle::lite::KernelRegistor<TARGET(target__),                      \
-                                      PRECISION(precision__), KernelClass>   \
-      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__,                     \
-                                    precision__)(#op_type__);                \
-  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
-  int touch_##op_type__##target__##precision__() {                           \
-    LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch();          \
-    return 0;                                                                \
-  }                                                                          \
-  static bool op_type__##target__##precision__##param_register               \
-      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
-          TARGET(target__), PRECISION(precision__)>(#op_type__)
-
-#define USE_LITE_KERNEL(op_type__, target__, precision__)         \
-  extern int touch_##op_type__##target__##precision__();          \
-  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-      __attribute__((unused)) = touch_##op_type__##target__##precision__();
-
-#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
-  op_type__##target__##precision__
+#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
+                                      alias__)                          \
+  op_type__##__##target__##__##precision__##__registor__instance__##alias__
+#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
+  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
+
+#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
+                             alias__)                                       \
+  static paddle::lite::KernelRegistor<TARGET(target__),                     \
+                                      PRECISION(precision__), KernelClass>  \
+      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__,       \
+                                    alias__)(#op_type__);                   \
+  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
+                                          alias__);                         \
+  int touch_##op_type__##target__##precision__##alias__() {                 \
+    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
+    return 0;                                                                \
+  }                                                                          \
+  static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__,   \
+                                         alias__) __attribute__((unused)) =  \
+      paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__),         \
+                                                   PRECISION(precision__)>(  \
+          #op_type__)
+
+#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__)        \
+  extern int touch_##op_type__##target__##precision__##alias__();         \
+  int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
+      touch_##op_type__##target__##precision__##alias__();
+
+#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__
+#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__##param_register
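Token pasting is what makes the aliases work: every generated symbol now ends in the alias, so two kernels registered for the same (op, target, precision) produce distinct globals and link cleanly side by side. Expanding USE_LITE_KERNEL(fc, kHost, kFloat, def) by hand gives roughly:

// Hand expansion of USE_LITE_KERNEL(fc, kHost, kFloat, def). A second kernel
// used under a different (hypothetical) alias, say eigen, would expand to
// touch_fckHostkFloateigen() and therefore not clash with these symbols.
extern int touch_fckHostkFloatdef();
int fckHostkFloatdef __attribute__((unused)) = touch_fckHostkFloatdef();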
@@ -24,7 +24,7 @@ namespace host {
 // NOTE should use pure std C++ implementation.
 void FcCompute::Run() {
-  auto& param = this->param<operators::FcParam>();
+  auto& param = this->Param<operators::FcParam>();
 
   CHECK_GE(param.input->dims().size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);
@@ -51,7 +51,8 @@ void FcCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
+REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
+                     def)
     .BindInput("Input",
                {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                    TARGET(kHost))})
......
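A side rename in the same commit: kernels now reach their typed parameters through the capitalized Param<T>() accessor, as in:

// After the rename; the old spelling was param<operators::FcParam>().
auto& param = this->Param<operators::FcParam>();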
@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::FeedParam;
 
   void Run() override {
-    auto &theparam = param<operators::FeedParam>();
+    auto &theparam = Param<operators::FeedParam>();
     const Tensor &feed_item = theparam.feed_list->at(theparam.col);
     theparam.out->CopyDataFrom(feed_item);
   }
@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(feed, kHost, kFloat,
-                     paddle::lite::kernels::host::FeedCompute)
+                     paddle::lite::kernels::host::FeedCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
......
@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<operators::MulParam>();
+    auto& theparam = Param<operators::MulParam>();
     core::dim2 x_shape(
         {product(theparam.x->dims().begin(),
                  theparam.x->dims().begin() + theparam.x_num_col_dims),
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
-                     paddle::lite::kernels::host::MulCompute)
+                     paddle::lite::kernels::host::MulCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
......
@@ -24,7 +24,7 @@ namespace host {
 class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
  public:
   void Run() override {
-    auto& theparam = param<operators::ReluParam>();
+    auto& theparam = Param<operators::ReluParam>();
     auto n = product(theparam.input->dims());
     const float* input = theparam.input->data<float>();
     float* output = theparam.output->mutable_data<float>();
@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(relu, kHost, kFloat,
-                     paddle::lite::kernels::host::ReluCompute)
+                     paddle::lite::kernels::host::ReluCompute, def)
     .Finalize();
@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = param<operators::ScaleParam>();
+    auto& theparam = Param<operators::ScaleParam>();
     scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
                   product(theparam.x->dims()), theparam.scale, theparam.bias,
                   theparam.bias_after_scale);
@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(scale, kHost, kFloat,
-                     paddle::lite::kernels::host::ScaleCompute)
+                     paddle::lite::kernels::host::ScaleCompute, def)
     .Finalize();
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "fc_op.h"
+#include "paddle/fluid/lite/operators/fc_op.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
 namespace paddle {
......
@@ -14,6 +14,7 @@
 #pragma once
 #include <iostream>
+#include <list>
 #include <memory>
 #include <sstream>
 #include <unordered_map>
@@ -49,13 +50,21 @@ class Factory {
   void Register(const std::string& op_type, creator_t&& creator) {
     CHECK(!creators_.count(op_type)) << "The op " << op_type
                                      << " has already registered";
-    creators_.emplace(op_type, std::move(creator));
+    creators_[op_type].emplace_back(std::move(creator));
   }
 
   item_ptr_t Create(const std::string& op_type) const {
+    return std::move(Creates(op_type).front());
+  }
+
+  std::list<item_ptr_t> Creates(const std::string& op_type) const {
     auto it = creators_.find(op_type);
     CHECK(it != creators_.end()) << "no item called " << op_type;
-    return it->second();
+    std::list<item_ptr_t> res;
+    for (auto& c : it->second) {
+      res.emplace_back(c());
+    }
+    return res;
   }
 
   std::string DebugString() const {
@@ -67,7 +76,7 @@ class Factory {
   }
 
  protected:
-  std::unordered_map<std::string, creator_t> creators_;
+  std::unordered_map<std::string, std::list<creator_t>> creators_;
 };
 
 /* A helper function to help run a lambda at the start.
......
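The Factory change is the core of the commit: the value type of the creator map goes from one creator to a list of creators, and Creates() instantiates all of them. A self-contained model of that data-structure change (plain C++ mimicking the pattern, not the project's actual classes):

#include <cassert>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>

// Minimal stand-in types so the example compiles on its own.
struct KernelBase { virtual ~KernelBase() = default; };
struct FcV0 : KernelBase {};
struct FcV1 : KernelBase {};

class Factory {
 public:
  using creator_t = std::function<std::unique_ptr<KernelBase>()>;

  // Key -> list of creators: several kernels may share one key.
  void Register(const std::string& key, creator_t&& creator) {
    creators_[key].emplace_back(std::move(creator));
  }

  // Instantiate every kernel registered under the key; empty list if none.
  std::list<std::unique_ptr<KernelBase>> Creates(const std::string& key) const {
    std::list<std::unique_ptr<KernelBase>> res;
    auto it = creators_.find(key);
    if (it == creators_.end()) return res;
    for (auto& c : it->second) res.emplace_back(c());
    return res;
  }

 private:
  std::unordered_map<std::string, std::list<creator_t>> creators_;
};

int main() {
  Factory f;
  f.Register("fc", [] { return std::unique_ptr<KernelBase>(new FcV0); });
  f.Register("fc", [] { return std::unique_ptr<KernelBase>(new FcV1); });
  assert(f.Creates("fc").size() == 2);  // both registrations come back
  return 0;
}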