diff --git a/examples/erniesage/models/base.py b/examples/erniesage/models/base.py index f59f54b36d5727ce7acec6b1b782d0724e3b4be6..e93fd5ffbb00de1da1612b0945cdba682909f164 100644 --- a/examples/erniesage/models/base.py +++ b/examples/erniesage/models/base.py @@ -191,12 +191,12 @@ def all_gather(X): for i in range(trainer_num): copy_X = X * 1 copy_X = L.collective._broadcast(copy_X, i, True) - copy_X.stop_gradients=True + copy_X.stop_gradient=True Xs.append(copy_X) if len(Xs) > 1: Xs=L.concat(Xs, 0) - Xs.stop_gradients=True + Xs.stop_gradient=True else: Xs = Xs[0] return Xs diff --git a/examples/erniesage/models/erniesage_v2.py b/examples/erniesage/models/erniesage_v2.py index fec39f2f927be7f86dc88b6050ae7c45b096c822..78fed26cc56b31c0e1e604e5d01b51657ee48fd2 100644 --- a/examples/erniesage/models/erniesage_v2.py +++ b/examples/erniesage/models/erniesage_v2.py @@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet): src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1] zero = L.fill_constant([1], dtype='int64', value=0) input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1] - src_pad_len = L.reduce_sum(input_mask, 1) # [B, 1, 1] + src_pad_len = L.reduce_sum(input_mask, 1, keep_dim=True) # [B, 1, 1] dst_position_ids = L.reshape( L.range(