提交 b4314f6d 编写于 作者: C ceci3

update combine

上级 464f6123
......@@ -60,10 +60,10 @@ class CombineSearchSpace(object):
Combine init tokens.
"""
tokens = []
self.token = []
self.single_token_num = []
for space in self.spaces:
tokens.extend(space.init_tokens())
self.token.append(space.init_tokens())
self.single_token_num.append(len(space.init_tokens()))
return tokens
def range_table(self):
......@@ -80,10 +80,19 @@ class CombineSearchSpace(object):
Combine model arch
"""
if tokens is None:
self.init_tokens()
tokens = self.init_tokens()
token_list = []
start_idx = 0
end_idx = 0
for i in range(len(self.single_token_num)):
end_idx += self.single_token_num[i]
token_list.append(tokens[start_idx:end_idx])
start_idx = end_idx
model_archs = []
for space, token in zip(self.spaces, self.token):
for space, token in zip(self.spaces, token_list):
model_archs.append(space.token2arch(token))
return model_archs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册