Unverified commit ccab0e2a, authored by 傅剑寒, committed by GitHub

fix uniform_rand_kernel FP16 support in dygraph mode (#46212)

Parent: 596d8209
@@ -74,8 +74,12 @@ void UniformRandomRawKernel(const Context& dev_ctx,
     funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
   } else {
     // Use OP seed
-    auto func = UniformGenerator<T>(
-        min.to<float>(), max.to<float>(), seed, diag_num, diag_step, diag_val);
+    auto func = UniformGenerator<T>(static_cast<T>(min.to<float>()),
+                                    static_cast<T>(max.to<float>()),
+                                    seed,
+                                    diag_num,
+                                    diag_step,
+                                    static_cast<T>(diag_val));
     IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
   }
 }
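
Why the casts are needed: UniformGenerator<T> takes min, max, and diag_val as T, and phi::dtype::float16's converting constructor from float is explicit, so the old call only compiled while T was float or double; registering an FP16 kernel turned it into a build error. Below is a minimal self-contained sketch of the failure and the fix, using a stand-in float16 type (the explicit constructor is the assumption that matters):

#include <cstdint>

// Stand-in for phi::dtype::float16. Its converting constructor from float
// is explicit, mirroring the real type (assumption for illustration).
struct float16 {
  uint16_t bits;
  explicit float16(float f) : bits(0) { (void)f; /* real code converts binary32 to binary16 */ }
};

// Same shape as the real functor's constructor: value parameters typed T.
template <typename T>
struct UniformGeneratorSketch {
  T min_, max_, diag_val_;
  UniformGeneratorSketch(T min, T max, int seed, int diag_num, int diag_step, T diag_val)
      : min_(min), max_(max), diag_val_(diag_val) {}
};

int main() {
  // Does not compile with T = float16: float will not convert implicitly.
  // UniformGeneratorSketch<float16> bad(0.0f, 1.0f, 0, 0, 0, 0.0f);

  // The shape of the patch's fix: cast every T-typed argument explicitly.
  UniformGeneratorSketch<float16> ok(static_cast<float16>(0.0f),
                                     static_cast<float16>(1.0f),
                                     0, 0, 0,
                                     static_cast<float16>(0.0f));
  (void)ok;
  return 0;
}
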
@@ -87,4 +91,5 @@ PD_REGISTER_KERNEL(uniform_random_raw,
                    ALL_LAYOUT,
                    phi::UniformRandomRawKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
@@ -51,8 +51,13 @@ PD_REGISTER_KERNEL(uniform_random,
                    phi::dtype::bfloat16) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(
-    uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {}
+PD_REGISTER_KERNEL(uniform_random,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::UniformRandomKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
 #endif
 #ifdef PADDLE_WITH_XPU
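The registration hunks are what make the casts necessary in the first place: PD_REGISTER_KERNEL stamps out the kernel template once per listed dtype, so appending phi::dtype::float16 forces the whole kernel body to compile with T = float16. A rough sketch of that mechanism follows (illustrative names only, not Paddle's real macro internals):

struct float16 {};  // stand-in for phi::dtype::float16

// The kernel is an ordinary function template; a registration macro
// ultimately triggers one instantiation per dtype it lists.
template <typename T>
void UniformRandomKernelSketch() {
  // Everything here must compile for every instantiated T; this is
  // where the missing static_casts used to break the FP16 build.
}

// Roughly what listing float, double, phi::dtype::float16 amounts to:
template void UniformRandomKernelSketch<float>();
template void UniformRandomKernelSketch<double>();
template void UniformRandomKernelSketch<float16>();  // new in this commit
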
@@ -585,8 +585,17 @@ class TestUniformDtype(unittest.TestCase):
             out = paddle.tensor.random.uniform([2, 3])
             self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
 
+        def test_dygraph_fp16():
+            if not paddle.is_compiled_with_cuda():
+                paddle.enable_static()
+                return
+            paddle.set_device('gpu')
+            out = paddle.uniform([2, 3], dtype=paddle.float16)
+            self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)
+
         test_default_fp64()
         test_default_fp32()
+        test_dygraph_fp16()
 
         paddle.enable_static()