Commit 08da20a6 authored by suweiyue

1. fix stop_gradient, 2. reduce_sum keep_dim

Parent e68b8b25
...@@ -191,12 +191,12 @@ def all_gather(X): ...@@ -191,12 +191,12 @@ def all_gather(X):
for i in range(trainer_num): for i in range(trainer_num):
copy_X = X * 1 copy_X = X * 1
copy_X = L.collective._broadcast(copy_X, i, True) copy_X = L.collective._broadcast(copy_X, i, True)
copy_X.stop_gradients=True copy_X.stop_gradient=True
Xs.append(copy_X) Xs.append(copy_X)
if len(Xs) > 1: if len(Xs) > 1:
Xs=L.concat(Xs, 0) Xs=L.concat(Xs, 0)
Xs.stop_gradients=True Xs.stop_gradient=True
else: else:
Xs = Xs[0] Xs = Xs[0]
return Xs return Xs
......
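Why fix 1 matters: the flag Paddle's autograd actually reads on a Variable is `stop_gradient`; assigning to the misspelled plural `stop_gradients` silently creates an unrelated Python attribute, so gradients kept flowing through the gathered copies. A minimal sketch (plain Python, with a hypothetical `Var` standing in for a fluid Variable):

```python
# Minimal sketch (plain Python; Var is a hypothetical stand-in for a
# fluid Variable) of why the typo was a silent no-op: assigning to a
# misspelled attribute creates a brand-new attribute on the object
# instead of flipping the real flag.
class Var:
    def __init__(self):
        self.stop_gradient = False  # the flag autograd actually reads

v = Var()
v.stop_gradients = True          # typo: creates a new, unused attribute
assert v.stop_gradient is False  # real flag unchanged, grads still flow
v.stop_gradient = True           # the fix: set the real attribute
assert v.stop_gradient is True
```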
@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
         src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1]
         zero = L.fill_constant([1], dtype='int64', value=0)
         input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1]
-        src_pad_len = L.reduce_sum(input_mask, 1) # [B, 1, 1]
+        src_pad_len = L.reduce_sum(input_mask, 1, keep_dim=True) # [B, 1, 1]
         dst_position_ids = L.reshape(
             L.range(
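Why fix 2 matters: `reduce_sum` drops the reduced axis by default, so the old call produced `src_pad_len` with shape [B, 1] while the trailing comment (and downstream broadcasting) expect [B, 1, 1]; `keep_dim=True` preserves the reduced axis. A short sketch, assuming `L` is `paddle.fluid.layers` on the 1.x API as in this repo:

```python
# Sketch assuming paddle.fluid 1.x, where L = paddle.fluid.layers:
# reduce_sum squeezes the reduced axis unless keep_dim=True is passed.
import paddle.fluid as fluid
import paddle.fluid.layers as L

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # fluid.layers.data prepends the batch dim, giving [-1, slot_seqlen, 1]
    input_mask = L.data(name="input_mask", shape=[128, 1], dtype="int32")
    squeezed = L.reduce_sum(input_mask, 1)             # shape [-1, 1]
    kept = L.reduce_sum(input_mask, 1, keep_dim=True)  # shape [-1, 1, 1]
    print(squeezed.shape, kept.shape)
```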