Unverified commit 060e4fab authored by houj04, committed by GitHub

[XPU] using xpu::normal in gaussian kernel. (#54176)

Parent 74ec1993
cmake/external/xpu.cmake
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
 set(XPU_XFT_LIB_NAME "libxft.so")
-set(XPU_BASE_DATE "20230523")
+set(XPU_BASE_DATE "20230529")
 set(XPU_XCCL_BASE_VERSION "1.0.49.2")
 set(XPU_XFT_BASE_VERSION "latest")
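The only change in this file is bumping the pinned XPU toolchain snapshot from 20230523 to 20230529; presumably the newer XDNN snapshot is what ships the device-side xpu::normal primitive that the kernel below switches to.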
paddle/phi/kernels/xpu/gaussian_kernel.cc
@@ -15,7 +15,6 @@
 #include "paddle/phi/kernels/gaussian_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -29,30 +28,19 @@ void GaussianKernel(const Context& ctx,
                     int seed,
                     DataType dtype,
                     DenseTensor* out) {
-  std::normal_distribution<float> dist(mean, std);
-  int64_t size = out->numel();
-  ctx.template Alloc<T>(out);
-  auto* data = out->data();
-  uint64_t seed_v = static_cast<uint64_t>(seed);
-  // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
-  // corresponding XPU device.
-  std::shared_ptr<std::mt19937_64> engine;
-  if (seed_v) {
-    engine = std::make_shared<std::mt19937_64>();
-    engine->seed(seed_v);
-  } else {
-    engine = ctx.GetGenerator()->GetCPUEngine();
-  }
+  out->Resize(phi::make_ddim(shape.GetData()));
+  T* data = ctx.template Alloc<T>(out);
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int64_t real_seed = seed != 0 ? seed : ctx.GetGenerator()->Random64();
-  std::unique_ptr<T[]> data_cpu(new T[size]);
-  for (int64_t i = 0; i < size; ++i) {
-    data_cpu[i] = dist(*engine);
-  }
-  memory_utils::Copy(ctx.GetPlace(),
-                     data,
-                     phi::CPUPlace(),
-                     reinterpret_cast<void*>(data_cpu.get()),
-                     size * sizeof(T));
+
+  // int normal(Context* ctx, T* x, T mean, T std, int64_t len, int64_t seed);
+  int r = xpu::normal<XPUType>(ctx.x_context(),
+                               reinterpret_cast<XPUType*>(data),
+                               mean,
+                               std,
+                               out->numel(),
+                               real_seed);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "normal");
 }
 }  // namespace phi
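To make the change above concrete: the deleted path drew every sample on the host with std::normal_distribution and then copied the whole buffer to the XPU, while the new path keeps only the seed choice on the host and lets the device fill the tensor. Below is a minimal standalone sketch of both pieces of logic — host_side_gaussian, pick_seed, and the fallback_seed parameter are illustrative stand-ins, not Paddle APIs:

// Sketch only: mirrors the logic of this diff outside of Paddle.
#include <cstdint>
#include <random>
#include <vector>

// Old path: sample on the CPU, after which the kernel shipped the buffer to
// the XPU with memory_utils::Copy -- one host allocation plus one
// host-to-device transfer per call.
std::vector<float> host_side_gaussian(int64_t size, float mean, float stddev,
                                      uint64_t seed, uint64_t fallback_seed) {
  // Old seed handling: an explicit nonzero seed wins, otherwise fall back to
  // the framework generator (stubbed here as fallback_seed).
  std::mt19937_64 engine(seed != 0 ? seed : fallback_seed);
  std::normal_distribution<float> dist(mean, stddev);
  std::vector<float> data(size);
  for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
  return data;
}

// New path: only this seed rule runs on the host; the sampling itself
// happens on the device via xpu::normal(ctx, x, mean, std, len, seed).
int64_t pick_seed(int seed, uint64_t generator_random64) {
  return seed != 0 ? static_cast<int64_t>(seed)
                   : static_cast<int64_t>(generator_random64);
}

Note that the new kernel also resizes out from the shape argument before allocating, rather than relying on the tensor already having its final size.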
paddle/phi/kernels/xpu/gelu_kernel.cc
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/gelu_kernel.h"
+#include "glog/logging.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/common/float16.h"
@@ -26,6 +27,9 @@ void GeluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 bool approximate,
                 DenseTensor* out) {
+  if (approximate) {
+    LOG_FIRST_N(INFO, 1) << "XPU does not support gelu with approximate.";
+  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
   int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
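The added warning says the XPU backend ignores the approximate flag and always computes its one gelu variant. For reference, these are the two standard GELU formulas the flag normally selects between — a self-contained C++ sketch, independent of the XDNN API:

// Reference math for GELU's exact (erf) and tanh-approximated forms.
#include <cmath>
#include <cstdio>

float gelu_exact(float x) {
  // 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

float gelu_tanh(float x) {
  // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  const float k = 0.7978845608f;  // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

int main() {
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    // The two variants differ only slightly over typical activation ranges,
    // which is why silently substituting one for the other is usually
    // tolerable and worth no more than a one-time log message.
    std::printf("x=% .1f  exact=% .6f  tanh=% .6f\n",
                x, gelu_exact(x), gelu_tanh(x));
  }
  return 0;
}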