提交 6933c683 编写于 作者: Y Yelrose

fixed dtype check

上级 1ff69d40
......@@ -269,7 +269,7 @@ def masked_select(input, mask):
def ensure_dtype(input, dtype):
if input.dtype == dtype:
if str(input.dtype) == dtype:
return input
else:
return fluid.layers.cast(input, dtype=dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册