未验证 提交 e23e51ff 编写于 作者: W Wenyu 提交者: GitHub

fix drop_prob type (#7682)

上级 6f384cb3
...@@ -32,7 +32,7 @@ def drop_path(x, drop_prob=0., training=False): ...@@ -32,7 +32,7 @@ def drop_path(x, drop_prob=0., training=False):
""" """
if drop_prob == 0. or not training: if drop_prob == 0. or not training:
return x return x
keep_prob = paddle.to_tensor(1 - drop_prob) keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize random_tensor = paddle.floor(random_tensor) # binarize
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册