提交 b4314f6d 编写于 作者: C ceci3

update combine

上级 464f6123
...@@ -60,10 +60,10 @@ class CombineSearchSpace(object): ...@@ -60,10 +60,10 @@ class CombineSearchSpace(object):
Combine init tokens. Combine init tokens.
""" """
tokens = [] tokens = []
self.token = [] self.single_token_num = []
for space in self.spaces: for space in self.spaces:
tokens.extend(space.init_tokens()) tokens.extend(space.init_tokens())
self.token.append(space.init_tokens()) self.single_token_num.append(len(space.init_tokens()))
return tokens return tokens
def range_table(self): def range_table(self):
...@@ -80,10 +80,19 @@ class CombineSearchSpace(object): ...@@ -80,10 +80,19 @@ class CombineSearchSpace(object):
Combine model arch Combine model arch
""" """
if tokens is None: if tokens is None:
self.init_tokens() tokens = self.init_tokens()
token_list = []
start_idx = 0
end_idx = 0
for i in range(len(self.single_token_num)):
end_idx += self.single_token_num[i]
token_list.append(tokens[start_idx:end_idx])
start_idx = end_idx
model_archs = [] model_archs = []
for space, token in zip(self.spaces, self.token): for space, token in zip(self.spaces, token_list):
model_archs.append(space.token2arch(token)) model_archs.append(space.token2arch(token))
return model_archs return model_archs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册