未验证 提交 ccab0e2a 编写于 作者: 傅剑寒 提交者: GitHub

fix uniform_rand_kernel FP16 support in dygraph mode (#46212)

上级 596d8209
...@@ -74,8 +74,12 @@ void UniformRandomRawKernel(const Context& dev_ctx, ...@@ -74,8 +74,12 @@ void UniformRandomRawKernel(const Context& dev_ctx,
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans); funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
} else { } else {
// Use OP seed // Use OP seed
auto func = UniformGenerator<T>( auto func = UniformGenerator<T>(static_cast<T>(min.to<float>()),
min.to<float>(), max.to<float>(), seed, diag_num, diag_step, diag_val); static_cast<T>(max.to<float>()),
seed,
diag_num,
diag_step,
static_cast<T>(diag_val));
IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func); IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
} }
} }
...@@ -87,4 +91,5 @@ PD_REGISTER_KERNEL(uniform_random_raw, ...@@ -87,4 +91,5 @@ PD_REGISTER_KERNEL(uniform_random_raw,
ALL_LAYOUT, ALL_LAYOUT,
phi::UniformRandomRawKernel, phi::UniformRandomRawKernel,
float, float,
double) {} double,
phi::dtype::float16) {}
...@@ -51,8 +51,13 @@ PD_REGISTER_KERNEL(uniform_random, ...@@ -51,8 +51,13 @@ PD_REGISTER_KERNEL(uniform_random,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(uniform_random,
uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {} GPU,
ALL_LAYOUT,
phi::UniformRandomKernel,
float,
double,
phi::dtype::float16) {}
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
......
...@@ -585,8 +585,17 @@ class TestUniformDtype(unittest.TestCase): ...@@ -585,8 +585,17 @@ class TestUniformDtype(unittest.TestCase):
out = paddle.tensor.random.uniform([2, 3]) out = paddle.tensor.random.uniform([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
def test_dygraph_fp16():
if not paddle.is_compiled_with_cuda():
paddle.enable_static()
return
paddle.set_device('gpu')
out = paddle.uniform([2, 3], dtype=paddle.float16)
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)
test_default_fp64() test_default_fp64()
test_default_fp32() test_default_fp32()
test_dygraph_fp16()
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册