提交 cd7018bf 编写于 作者: S superjomn

refactor ParamTypeRecorder

上级 cae5a931
...@@ -34,8 +34,8 @@ bool ParamTypeRegistry::KeyCmp::operator()( ...@@ -34,8 +34,8 @@ bool ParamTypeRegistry::KeyCmp::operator()(
return a.kernel_type < b.kernel_type; return a.kernel_type < b.kernel_type;
else if (a.io != b.io) else if (a.io != b.io)
return a.io < b.io; return a.io < b.io;
else if (a.offset != b.offset) else if (a.arg_name != b.arg_name)
return a.offset < b.offset; return a.arg_name < b.arg_name;
else if (!(a.place == b.place)) { else if (!(a.place == b.place)) {
return a.place < b.place; return a.place < b.place;
} }
......
...@@ -94,27 +94,22 @@ struct ParamType { ...@@ -94,27 +94,22 @@ struct ParamType {
* The data types of kernel parameters. It is used to track the type of kernel's * The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs. * inputs and outputs.
*/ */
struct ParamTypes { struct ParamTypeRecorder {
std::vector<std::vector<ParamType>> inputs; std::map<std::string, ParamType> inputs;
std::vector<std::vector<ParamType>> outputs; std::map<std::string, ParamType> outputs;
void RegisterInputType(int offset, const ParamType& type) { void RegisterInputType(const std::string& arg_name, const ParamType& type) {
Register(&inputs, offset, type); Register(&inputs, arg_name, type);
} }
void RegisterOutputType(int offset, const ParamType& type) { void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
Register(&outputs, offset, type); Register(&outputs, arg_name, type);
} }
private: private:
void Register(std::vector<std::vector<ParamType>>* ts, int offset, void Register(std::map<std::string, ParamType>* ts,
ParamType type) { const std::string& arg_name, ParamType type) {
CHECK_GE(offset, 0) << "invalid offset"; (*ts)[arg_name] = type;
CHECK_GE(offset, 50) << "invalid offset";
for (size_t i = 0; i < offset - inputs.size() + 1; i++) {
ts->emplace_back();
}
ts->at(offset).emplace_back(type);
} }
}; };
...@@ -148,14 +143,16 @@ class ParamTypeRegistry { ...@@ -148,14 +143,16 @@ class ParamTypeRegistry {
explicit NewInstance(const std::string& kernel_type) explicit NewInstance(const std::string& kernel_type)
: kernel_type_(kernel_type) {} : kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) { NewInstance& BindInput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kInput>( ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, offset, ptype); kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this; return *this;
} }
NewInstance& BindOutput(int offset, const ParamType& ptype) { NewInstance& BindOutput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>( ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, offset, ptype); kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this; return *this;
} }
...@@ -166,9 +163,9 @@ class ParamTypeRegistry { ...@@ -166,9 +163,9 @@ class ParamTypeRegistry {
}; };
template <IO io> template <IO io>
void Register(const std::string& kernel_type, const Place& place, int offset, void Register(const std::string& kernel_type, const Place& place,
ParamType data_type) { const std::string& arg_name, ParamType data_type) {
KernelIdTy key{kernel_type, place, io, offset}; KernelIdTy key{kernel_type, place, io, arg_name};
types_[key] = data_type; types_[key] = data_type;
} }
...@@ -188,7 +185,7 @@ class ParamTypeRegistry { ...@@ -188,7 +185,7 @@ class ParamTypeRegistry {
std::string kernel_type; std::string kernel_type;
Place place; Place place;
IO io; IO io;
int offset; std::string arg_name;
}; };
using key_t = KernelIdTy; using key_t = KernelIdTy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册