提交 cd7018bf 编写于 作者: S superjomn

refactor ParamTypeRecorder

上级 cae5a931
......@@ -34,8 +34,8 @@ bool ParamTypeRegistry::KeyCmp::operator()(
return a.kernel_type < b.kernel_type;
else if (a.io != b.io)
return a.io < b.io;
else if (a.offset != b.offset)
return a.offset < b.offset;
else if (a.arg_name != b.arg_name)
return a.arg_name < b.arg_name;
else if (!(a.place == b.place)) {
return a.place < b.place;
}
......
......@@ -94,27 +94,22 @@ struct ParamType {
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct ParamTypes {
std::vector<std::vector<ParamType>> inputs;
std::vector<std::vector<ParamType>> outputs;
struct ParamTypeRecorder {
std::map<std::string, ParamType> inputs;
std::map<std::string, ParamType> outputs;
void RegisterInputType(int offset, const ParamType& type) {
Register(&inputs, offset, type);
void RegisterInputType(const std::string& arg_name, const ParamType& type) {
Register(&inputs, arg_name, type);
}
void RegisterOutputType(int offset, const ParamType& type) {
Register(&outputs, offset, type);
void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
Register(&outputs, arg_name, type);
}
private:
void Register(std::vector<std::vector<ParamType>>* ts, int offset,
ParamType type) {
CHECK_GE(offset, 0) << "invalid offset";
CHECK_GE(offset, 50) << "invalid offset";
for (size_t i = 0; i < offset - inputs.size() + 1; i++) {
ts->emplace_back();
}
ts->at(offset).emplace_back(type);
void Register(std::map<std::string, ParamType>* ts,
const std::string& arg_name, ParamType type) {
(*ts)[arg_name] = type;
}
};
......@@ -148,14 +143,16 @@ class ParamTypeRegistry {
explicit NewInstance(const std::string& kernel_type)
: kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) {
NewInstance& BindInput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, offset, ptype);
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
NewInstance& BindOutput(int offset, const ParamType& ptype) {
NewInstance& BindOutput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, offset, ptype);
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
......@@ -166,9 +163,9 @@ class ParamTypeRegistry {
};
template <IO io>
void Register(const std::string& kernel_type, const Place& place, int offset,
ParamType data_type) {
KernelIdTy key{kernel_type, place, io, offset};
void Register(const std::string& kernel_type, const Place& place,
const std::string& arg_name, ParamType data_type) {
KernelIdTy key{kernel_type, place, io, arg_name};
types_[key] = data_type;
}
......@@ -188,7 +185,7 @@ class ParamTypeRegistry {
std::string kernel_type;
Place place;
IO io;
int offset;
std::string arg_name;
};
using key_t = KernelIdTy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册