提交 29d24dbb 编写于 作者: M Megvii Engine Team

fix(mge/function): fix interpolate unsupport fp16 error

GitOrigin-RevId: 7fc6271986ce94b496e7be42b1d4e8500a3ca921
上级 36df3850
......@@ -582,7 +582,8 @@ def interpolate(
"nearest": "nearest",
"bicubic": "cubic",
}
if inp.dtype == np.float16:
inp = inp.astype("float32")
op = builtin.Resize(imode=mode_map[mode], format="NCHW")
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册