提交 6933c683 编写于 作者: Y Yelrose

fixed dtype check

上级 1ff69d40
...@@ -269,7 +269,7 @@ def masked_select(input, mask): ...@@ -269,7 +269,7 @@ def masked_select(input, mask):
def ensure_dtype(input, dtype): def ensure_dtype(input, dtype):
if input.dtype == dtype: if str(input.dtype) == dtype:
return input return input
else: else:
return fluid.layers.cast(input, dtype=dtype) return fluid.layers.cast(input, dtype=dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册