未验证 提交 d216cdc9 编写于 作者: C ceci3 提交者: GitHub

Fix ofa (#536)

* fix

* fix embedding
Co-authored-by: NBai Yifan <me@ethanbai.com>
上级 5590daf9
......@@ -528,7 +528,7 @@ class Convert:
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = ['padding_idx', ]
new_attr_name = []
if pd_ver == 185:
new_attr_name += [
'size', 'is_sparse', 'is_distributed', 'param_attr',
......@@ -584,6 +584,9 @@ class Convert:
for attr in new_attr_name:
new_attr_dict[attr] = attr_dict['_' + attr]
new_attr_dict['padding_idx'] = None if attr_dict[
'_padding_idx'] == -1 else attr_dict['_padding_idx']
del layer, attr_dict
layer = Block(SuperEmbedding(**new_attr_dict), key=key)
......
......@@ -937,7 +937,8 @@ class SuperEmbedding(nn.Embedding):
weight_attr=None,
name=None):
super(SuperEmbedding, self).__init__(num_embeddings, embedding_dim,
sparse, weight_attr, name)
padding_idx, sparse, weight_attr,
name)
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册