未验证 提交 ee694b68 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #1273 from sneaxiy/dam_develop

Fix DAM model bug
...@@ -128,7 +128,7 @@ def train(args): ...@@ -128,7 +128,7 @@ def train(args):
dev_count = fluid.core.get_cuda_device_count() dev_count = fluid.core.get_cuda_device_count()
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = multiprocessing.cpu_count() dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("device count %d" % dev_count) print("device count %d" % dev_count)
......
...@@ -82,7 +82,10 @@ def dot_product_attention(query, ...@@ -82,7 +82,10 @@ def dot_product_attention(query,
else: else:
mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True) mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
another_mask = fluid.layers.scale( another_mask = fluid.layers.scale(
mask, scale=2**32 - 1, bias=-1, bias_after_scale=False) mask,
scale=float(2**32 - 1),
bias=float(-1),
bias_after_scale=False)
if mask_cache is not None: if mask_cache is not None:
if q_mask.name not in mask_cache: if q_mask.name not in mask_cache:
mask_cache[q_mask.name] = dict() mask_cache[q_mask.name] = dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册