Unverified · Commit 8305ba37 · authored by TTerror, committed by GitHub

fix bn_infer and optimize momentum for kunlun (#35250)

Parent 8ba58eb0
@@ -35,7 +35,7 @@ ELSE ()
 ENDIF()
 SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210826")
+SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830")
 SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
@@ -76,26 +76,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
           W, epsilon, momentum, scale_data, bias_data,
           saved_mean_data, saved_variance_data,
           mean_out_data, variance_out_data, true);
-      PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU API(batch_norm_train_forward) return "
-                                     "wrong value[%d], please check whether "
-                                     "Baidu Kunlun Card is properly installed.",
-                                     r));
+      PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
+                        platform::errors::External(
+                            "The batch_norm XPU API return wrong value[%d %s]",
+                            r, XPUAPIErrorMsg[r]));
     } else {
       const auto* mean = ctx.Input<Tensor>("Mean");
       const auto* variance = ctx.Input<Tensor>("Variance");
-      const auto* mean_data = mean->data<T>();
-      const auto* variance_data = variance->data<T>();
-      int r = xpu::batch_norm_infer_forward(
-          dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
-          bias_data, mean_data, variance_data);
+      const auto* mean_data = mean->data<float>();
+      const auto* variance_data = variance->data<float>();
+      const auto* x_data = x->data<float>();
+      auto* y_data = y->mutable_data<float>(ctx.GetPlace());
+      int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
+                                    H, W, epsilon, scale_data, bias_data,
+                                    mean_data, variance_data, true);
       PADDLE_ENFORCE_EQ(
-          r, XPU_SUCCESS,
-          platform::errors::External("XPU API(batch_norm_infer_forward) return "
-                                     "wrong value[%d], please check whether "
-                                     "Baidu Kunlun Card is properly installed.",
-                                     r));
+          r, xpu::Error_t::SUCCESS,
+          platform::errors::External(
+              "The batch_norm_infer XPU API return wrong value[%d %s]", r,
+              XPUAPIErrorMsg[r]));
     }
   }
 };
......
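Both checks in this hunk follow the same pattern: every xpu::* call returns an int status, which is compared against xpu::Error_t::SUCCESS and reported through XPUAPIErrorMsg. The sketch below only illustrates that pattern; CheckXPUStatus is a hypothetical helper, not part of the patch, and it assumes the Paddle/XDNN headers already included by the kernel file.

// Sketch only: the error-reporting pattern the new code standardizes on.
// CheckXPUStatus is a hypothetical name, not part of this patch.
inline void CheckXPUStatus(int r, const char* api_name) {
  PADDLE_ENFORCE_EQ(
      r, xpu::Error_t::SUCCESS,
      platform::errors::External("The %s XPU API return wrong value[%d %s]",
                                 api_name, r, XPUAPIErrorMsg[r]));
}

// Usage, mirroring the inference branch above:
//   int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
//                                 H, W, epsilon, scale_data, bias_data,
//                                 mean_data, variance_data, true);
//   CheckXPUStatus(r, "batch_norm_infer");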
@@ -44,10 +44,10 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
     auto grad = ctx.Input<framework::Tensor>("Grad");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();

-    int r = xpu::momentum(
-        dev_ctx.x_context(), param->data<float>(), velocity->data<float>(),
-        grad->data<float>(), lr, use_nesterov, mu, param_out->numel(),
-        param_out->data<float>(), velocity_out->data<float>());
+    int r = xpu::momentum(dev_ctx.x_context(), param->data<float>(),
+                          velocity->data<float>(), grad->data<float>(),
+                          param_out->data<float>(), velocity_out->data<float>(),
+                          param_out->numel(), lr, use_nesterov, mu);
     if (r == xpu::Error_t::INVALID_PARAM) {
       PADDLE_ENFORCE_EQ(
           r, xpu::Error_t::SUCCESS,
......
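For context, the element-wise update the fused xpu::momentum kernel performs is the standard Paddle momentum rule. The host-side sketch below is illustrative only: MomentumReference is not in the patch, and it takes lr as a scalar, whereas the kernel receives the learning-rate tensor's device pointer.

// Host-side reference sketch of the momentum update (illustrative only).
void MomentumReference(const float* param, const float* velocity,
                       const float* grad, float* param_out,
                       float* velocity_out, int64_t numel, float lr,
                       bool use_nesterov, float mu) {
  for (int64_t i = 0; i < numel; ++i) {
    // velocity_out = mu * velocity + grad
    float v = mu * velocity[i] + grad[i];
    velocity_out[i] = v;
    // Nesterov looks ahead along the updated velocity; plain momentum does not.
    param_out[i] = use_nesterov ? param[i] - lr * (grad[i] + mu * v)
                                : param[i] - lr * v;
  }
}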
@@ -75,6 +75,10 @@ XPUOpMap& get_kl2_ops() {
     {"elementwise_min_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"batch_norm_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     // AddMore
   };
......
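The three new entries register momentum, batch_norm, and batch_norm_grad as FP32 kernels for KL2 (XPU2) devices. Assuming XPUOpMap maps an op name to a set of supported kernel types, as the initializer above suggests, a lookup would resemble the sketch below; supports_fp32 is a hypothetical helper, not part of the patch.

// Illustrative only: checking whether an op has an FP32 XPU kernel registered.
// Assumes XPUOpMap is a map from op name to a set of kernel types, as the
// brace-initialized entries above suggest.
bool supports_fp32(const XPUOpMap& ops, const std::string& op_name) {
  auto it = ops.find(op_name);
  if (it == ops.end()) return false;
  return it->second.count(pOpKernelType(vartype::FP32, XPUPlace())) > 0;
}

// After this patch: supports_fp32(get_kl2_ops(), "batch_norm") == true
//                   supports_fp32(get_kl2_ops(), "momentum")   == true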