提交 9a9dac5d 编写于 作者: G gaotingquan 提交者: cuicheng01

fix: set dtype in paddle.to_tensor()

上级 69aad7ea
......@@ -39,7 +39,7 @@ def drop_path(x, drop_prob=0., training=False):
"""
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
......
......@@ -45,7 +45,7 @@ def drop_path(x, drop_prob=0., training=False):
"""
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = paddle.add(keep_prob, paddle.rand(shape, dtype=x.dtype))
random_tensor = paddle.floor(random_tensor) # binarize
......
......@@ -42,7 +42,7 @@ def drop_path(x, drop_prob=0., training=False):
"""
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
......
......@@ -60,7 +60,7 @@ def drop_path(x, drop_prob=0., training=False):
"""
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册