From cd7018bf3c956e4e9a42b844ddcf3a9807f7cc15 Mon Sep 17 00:00:00 2001 From: superjomn Date: Thu, 18 Apr 2019 16:54:34 +0800 Subject: [PATCH] refactor ParamTypeRecorder --- paddle/fluid/lite/core/kernel.cc | 4 +- paddle/fluid/lite/core/kernel.h | 43 +++++++++---------- .../core/mir/variable_place_inference_pass.cc | 0 .../core/mir/variable_place_inference_pass.h | 1 + 4 files changed, 23 insertions(+), 25 deletions(-) create mode 100644 paddle/fluid/lite/core/mir/variable_place_inference_pass.cc create mode 100644 paddle/fluid/lite/core/mir/variable_place_inference_pass.h diff --git a/paddle/fluid/lite/core/kernel.cc b/paddle/fluid/lite/core/kernel.cc index 34e01982960..a4268c78375 100644 --- a/paddle/fluid/lite/core/kernel.cc +++ b/paddle/fluid/lite/core/kernel.cc @@ -34,8 +34,8 @@ bool ParamTypeRegistry::KeyCmp::operator()( return a.kernel_type < b.kernel_type; else if (a.io != b.io) return a.io < b.io; - else if (a.offset != b.offset) - return a.offset < b.offset; + else if (a.arg_name != b.arg_name) + return a.arg_name < b.arg_name; else if (!(a.place == b.place)) { return a.place < b.place; } diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index f83ae0a5222..f5e31233104 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -94,27 +94,22 @@ struct ParamType { * The data types of kernel parameters. It is used to track the type of kernel's * inputs and outputs. */ -struct ParamTypes { - std::vector> inputs; - std::vector> outputs; +struct ParamTypeRecorder { + std::map inputs; + std::map outputs; - void RegisterInputType(int offset, const ParamType& type) { - Register(&inputs, offset, type); + void RegisterInputType(const std::string& arg_name, const ParamType& type) { + Register(&inputs, arg_name, type); } - void RegisterOutputType(int offset, const ParamType& type) { - Register(&outputs, offset, type); + void RegisterOutputType(const std::string& arg_name, const ParamType& type) { + Register(&outputs, arg_name, type); } private: - void Register(std::vector>* ts, int offset, - ParamType type) { - CHECK_GE(offset, 0) << "invalid offset"; - CHECK_GE(offset, 50) << "invalid offset"; - for (size_t i = 0; i < offset - inputs.size() + 1; i++) { - ts->emplace_back(); - } - ts->at(offset).emplace_back(type); + void Register(std::map* ts, + const std::string& arg_name, ParamType type) { + (*ts)[arg_name] = type; } }; @@ -148,14 +143,16 @@ class ParamTypeRegistry { explicit NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {} - NewInstance& BindInput(int offset, const ParamType& ptype) { + NewInstance& BindInput(const std::string& arg_name, + const ParamType& ptype) { ParamTypeRegistry::Global().Register( - kernel_type_, Place{target, precision, layout}, offset, ptype); + kernel_type_, Place{target, precision, layout}, arg_name, ptype); return *this; } - NewInstance& BindOutput(int offset, const ParamType& ptype) { + NewInstance& BindOutput(const std::string& arg_name, + const ParamType& ptype) { ParamTypeRegistry::Global().Register( - kernel_type_, Place{target, precision, layout}, offset, ptype); + kernel_type_, Place{target, precision, layout}, arg_name, ptype); return *this; } @@ -166,9 +163,9 @@ class ParamTypeRegistry { }; template - void Register(const std::string& kernel_type, const Place& place, int offset, - ParamType data_type) { - KernelIdTy key{kernel_type, place, io, offset}; + void Register(const std::string& kernel_type, const Place& place, + const std::string& arg_name, ParamType data_type) { + KernelIdTy key{kernel_type, place, io, arg_name}; types_[key] = data_type; } @@ -188,7 +185,7 @@ class ParamTypeRegistry { std::string kernel_type; Place place; IO io; - int offset; + std::string arg_name; }; using key_t = KernelIdTy; diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc new file mode 100644 index 00000000000..e69de29bb2d diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -0,0 +1 @@ + -- GitLab