diff --git a/examples/erniesage/models/ernie_model/ernie.py b/examples/erniesage/models/ernie_model/ernie.py index e2c55485cbae0aef6d1672787321650d31902165..3a8b4650bd9a77882e458677215525940a3d7df2 100644 --- a/examples/erniesage/models/ernie_model/ernie.py +++ b/examples/erniesage/models/ernie_model/ernie.py @@ -104,7 +104,7 @@ class ErnieModel(object): zero = L.fill_constant([1], dtype='int64', value=0) input_mask = L.logical_not(L.equal(src_ids, zero)) # assume pad id == 0 - input_mask = L.cast(input_mask, 'float') + input_mask = L.cast(input_mask, 'float32') input_mask.stop_gradient = True return input_mask