提交 10ba176f 编写于 作者: B baiyfbupt

fix conflicts

......@@ -89,6 +89,7 @@ def main():
# whether use multi-gpus
device_num = fluid.dygraph.parallel.Env().nranks
use_data_parallel = device_num > 1
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
).dev_id) if use_data_parallel else fluid.CUDAPlace(0)
......
......@@ -108,6 +108,10 @@ class BertModelLayer(Layer):
return self._src_emb.parameters() + self._pos_emb.parameters(
) + self._sent_emb.parameters()
def emb_names(self):
return self._src_emb.parameters() + self._pos_emb.parameters(
) + self._sent_emb.parameters()
def max_flops(self):
return self._encoder.max_flops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册