Unverified commit fad4744a, authored by taixiurong, committed by GitHub

fix crash in adam in xpu, *test=kunlun (#28433)

Parent 6bba8e57
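Editor's note (a summary of the fix, not part of the original commit message): before this patch the kernel passed `beta1_pow.template data<T>()` and `beta2_pow.template data<T>()` straight to `xpu::adam`. When the `Beta1Pow`/`Beta2Pow` tensors are CPU-resident, those are host pointers that the Kunlun device kernel cannot safely dereference, which is the crash the title refers to. The patch guards against this by copying both tensors to the XPU first. A minimal sketch of that guard, reduced from the diff below (PaddlePaddle types assumed; only `beta1_pow` shown):

    // Default: use the tensor's own buffer, valid when it already lives on XPU.
    const T* beta1_pow_ptr = beta1_pow.template data<T>();
    Tensor xpu_beta1_pow;
    if (beta1_pow.place() == platform::CPUPlace()) {
      // Host memory: copy the tensor to the XPU and wait for the transfer
      // to finish before handing the pointer to the device kernel.
      TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
      dev_ctx.Wait();
      beta1_pow_ptr = xpu_beta1_pow.template data<T>();
    }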
@@ -74,7 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
@@ -88,30 +88,53 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
                                    "Grad", "Adam");
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      const T* beta1_pow_ptr = beta1_pow.template data<T>();
+      const T* beta2_pow_ptr = beta2_pow.template data<T>();
+      Tensor xpu_beta1_pow;
+      Tensor xpu_beta2_pow;
+      if (beta1_pow.place() == platform::CPUPlace() &&
+          beta2_pow.place() == platform::CPUPlace()) {
+        TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
+        TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
+        dev_ctx.Wait();
+        beta1_pow_ptr = xpu_beta1_pow.template data<T>();
+        beta2_pow_ptr = xpu_beta2_pow.template data<T>();
+      }
       int r = xpu::adam(
           dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
-          mom2.template data<T>(), param.template data<T>(),
-          beta1_pow.template data<T>(), beta2_pow.template data<T>(), beta1,
-          beta2, epsilon, lr.template data<T>(),
+          mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
+          beta2_pow_ptr, beta1, beta2, epsilon, lr.template data<T>(),
           mom1_out.template mutable_data<T>(ctx.GetPlace()),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
-      const float* ptr0 = beta1_pow.template data<T>();
-      float* ptr1 = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
-      float cpudata;
-      xpu_memcpy(&cpudata, ptr0, sizeof(float), XPU_DEVICE_TO_HOST);
-      cpudata = cpudata * beta1;
-      xpu_memcpy(ptr1, &cpudata, sizeof(float), XPU_HOST_TO_DEVICE);
-      const float* ptr2 = beta2_pow.template data<T>();
-      float* ptr3 = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
-      float cpudata1;
-      xpu_memcpy(&cpudata1, ptr2, sizeof(float), XPU_DEVICE_TO_HOST);
-      cpudata1 = cpudata1 * beta2;
-      xpu_memcpy(ptr3, &cpudata1, sizeof(float), XPU_HOST_TO_DEVICE);
+      // update in cpu and then copy to xpu
+      if (beta1_pow.place() == platform::CPUPlace() &&
+          beta2_pow.place() == platform::CPUPlace()) {
+        const T* beta1_pow_p = beta1_pow.template data<T>();
+        beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta1 * beta1_pow_p[0];
+        const T* beta2_pow_p = beta2_pow.template data<T>();
+        beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta2 * beta2_pow_p[0];
+      } else {
+        T cpu_beta1_pow_out_data;
+        T cpu_beta2_pow_out_data;
+        xpu_memcpy(&cpu_beta1_pow_out_data, beta1_pow_ptr, sizeof(T),
+                   XPU_DEVICE_TO_HOST);
+        cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1;
+        xpu_memcpy(&cpu_beta2_pow_out_data, beta2_pow_ptr, sizeof(T),
+                   XPU_DEVICE_TO_HOST);
+        cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2;
+        T* beta1_pow_out_p = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+        T* beta2_pow_out_p = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
+        xpu_memcpy(beta1_pow_out_p, &cpu_beta1_pow_out_data, sizeof(T),
+                   XPU_HOST_TO_DEVICE);
+        xpu_memcpy(beta2_pow_out_p, &cpu_beta2_pow_out_data, sizeof(T),
+                   XPU_HOST_TO_DEVICE);
+      }
       PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                         platform::errors::External(
......
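For reference, the `else` branch above keeps the `beta_pow` state on the device and updates it with a scalar round-trip: copy the value to the host, apply `beta_pow_{t+1} = beta * beta_pow_t`, and copy the product back. A self-contained sketch of that pattern (hypothetical helper name, illustrative only, not PaddlePaddle code; assumes `xpu_memcpy` and the copy-kind constants come from the Kunlun runtime header, e.g. `xpu/runtime.h`):

    #include "xpu/runtime.h"  // assumed location of xpu_memcpy and XPU_* kinds

    // Read one scalar from device memory, scale it on the host, and write
    // the result back to device memory, mirroring the else branch above.
    void UpdateBetaPowOnHost(const float* beta_pow_dev, float beta,
                             float* beta_pow_out_dev) {
      float beta_pow = 0.0f;
      xpu_memcpy(&beta_pow, beta_pow_dev, sizeof(float), XPU_DEVICE_TO_HOST);
      beta_pow *= beta;  // beta_pow_{t+1} = beta * beta_pow_t
      xpu_memcpy(beta_pow_out_dev, &beta_pow, sizeof(float),
                 XPU_HOST_TO_DEVICE);
    }

The CPU branch performs the same multiply directly on host memory, skipping both `xpu_memcpy` calls, which is why the patch special-cases CPU-resident `beta_pow` tensors.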