Unverified commit ee694b68, author: Zeng Jinle, committer: GitHub

Merge pull request #1273 from sneaxiy/dam_develop

Fix DAM model bug
......@@ -128,7 +128,7 @@ def train(args):
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
-dev_count = multiprocessing.cpu_count()
+dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("device count %d" % dev_count)
......
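Note on the hunk above: on the CPU path, the device count now comes from the `CPU_NUM` environment variable when it is set and falls back to the machine's core count otherwise. A minimal standalone sketch of that lookup follows. The helper name `resolve_cpu_dev_count` and its `environ` parameter are hypothetical, added only so the behaviour can be tried without Paddle.

```python
import multiprocessing
import os


def resolve_cpu_dev_count(environ=None):
    """Hypothetical helper mirroring the changed line in train().

    Reads CPU_NUM from the given mapping (os.environ by default) and
    falls back to the number of CPU cores when the variable is unset.
    """
    env = os.environ if environ is None else environ
    # Environment values are strings, so the int() cast is required.
    return int(env.get('CPU_NUM', multiprocessing.cpu_count()))


if __name__ == "__main__":
    print("device count %d" % resolve_cpu_dev_count())
```

Running with `CPU_NUM=4` set in the environment would report 4 devices regardless of the core count. A non-numeric value would raise `ValueError` from `int()`.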
......@@ -82,7 +82,10 @@ def dot_product_attention(query,
else:
mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
another_mask = fluid.layers.scale(
-mask, scale=2**32 - 1, bias=-1, bias_after_scale=False)
+mask,
+scale=float(2**32 - 1),
+bias=float(-1),
+bias_after_scale=False)
if mask_cache is not None:
if q_mask.name not in mask_cache:
mask_cache[q_mask.name] = dict()
......
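Note on the hunk above: with `bias_after_scale=False`, `fluid.layers.scale` computes `scale * (x + bias)`, so a 0/1 mask is mapped to 0 for kept positions and to a large negative value for masked ones. The diff does not say why the arguments are cast to `float`, and this note does not assume a reason. The plain-Python sketch below only reproduces the arithmetic. The function names are hypothetical and no Paddle API is called.

```python
def scale_bias_first(x, scale, bias):
    """Arithmetic of a scale op when the bias is applied before scaling."""
    return scale * (x + bias)


def additive_mask(mask_row):
    """Map a 0/1 mask row to additive attention-mask values.

    1 (keep) becomes 0.0; 0 (masked) becomes -(2**32 - 1).
    """
    scale = float(2**32 - 1)
    bias = float(-1)
    return [scale_bias_first(float(m), scale, bias) for m in mask_row]


if __name__ == "__main__":
    print(additive_mask([1, 0, 1]))
```

Adding these values to attention logits before a softmax would leave kept positions unchanged and push masked positions toward zero weight.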