未验证 提交 569b018e 编写于 作者: S shentanyue 提交者: GitHub

xpu gaussian_random support fp16 (#50881)

上级 4a9b694b
......@@ -338,7 +338,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"gaussian_random",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"gelu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
......@@ -29,7 +29,7 @@ void GaussianKernel(const Context& ctx,
int seed,
DataType dtype,
DenseTensor* out) {
std::normal_distribution<T> dist(mean, std);
std::normal_distribution<float> dist(mean, std);
int64_t size = out->numel();
ctx.template Alloc<T>(out);
auto* data = out->data();
......@@ -57,4 +57,9 @@ void GaussianKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(gaussian, XPU, ALL_LAYOUT, phi::GaussianKernel, float) {}
// Register the XPU "gaussian" kernel for float32 and float16 outputs
// (fp16 support added by PR #50881). NOTE(review): the dtype list here
// should stay in sync with the "gaussian_random" entry in the KL2 op map,
// which this same change extended with phi::DataType::FLOAT16.
PD_REGISTER_KERNEL(gaussian,
                   XPU,
                   ALL_LAYOUT,
                   phi::GaussianKernel,
                   float,
                   phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册