未验证 提交 8305ba37 编写于 作者: T TTerror 提交者: GitHub

fix bn_infer and optimize momentum for kunlun (#35250)

上级 8ba58eb0
......@@ -35,7 +35,7 @@ ELSE ()
ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210826")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
......@@ -76,26 +76,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
W, epsilon, momentum, scale_data, bias_data,
saved_mean_data, saved_variance_data,
mean_out_data, variance_out_data, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_train_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The batch_norm XPU API return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
int r = xpu::batch_norm_infer_forward(
dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
bias_data, mean_data, variance_data);
const auto* mean_data = mean->data<float>();
const auto* variance_data = variance->data<float>();
const auto* x_data = x->data<float>();
auto* y_data = y->mutable_data<float>(ctx.GetPlace());
int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
H, W, epsilon, scale_data, bias_data,
mean_data, variance_data, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The batch_norm_infer XPU API return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
}
};
......
......@@ -44,10 +44,10 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
auto grad = ctx.Input<framework::Tensor>("Grad");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::momentum(
dev_ctx.x_context(), param->data<float>(), velocity->data<float>(),
grad->data<float>(), lr, use_nesterov, mu, param_out->numel(),
param_out->data<float>(), velocity_out->data<float>());
int r = xpu::momentum(dev_ctx.x_context(), param->data<float>(),
velocity->data<float>(), grad->data<float>(),
param_out->data<float>(), velocity_out->data<float>(),
param_out->numel(), lr, use_nesterov, mu);
if (r == xpu::Error_t::INVALID_PARAM) {
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
......
......@@ -75,6 +75,10 @@ XPUOpMap& get_kl2_ops() {
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// AddMore
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册