From 672d56d7b09638f276d97e34e92da18356b88c5f Mon Sep 17 00:00:00 2001
From: lijianshe02
Date: Thu, 20 Jun 2019 09:19:45 +0000
Subject: [PATCH] add server batch_norm kernel and unit test

---
 paddle/fluid/lite/kernels/x86/CMakeLists.txt  |   3 +
 .../lite/kernels/x86/batch_norm_compute.cc    |  29 +++
 .../lite/kernels/x86/batch_norm_compute.h     | 158 ++++++++++++++++++
 .../kernels/x86/batch_norm_compute_test.cc    | 141 +++++++++++++++
 4 files changed, 331 insertions(+)
 create mode 100644 paddle/fluid/lite/kernels/x86/batch_norm_compute.cc
 create mode 100644 paddle/fluid/lite/kernels/x86/batch_norm_compute.h
 create mode 100644 paddle/fluid/lite/kernels/x86/batch_norm_compute_test.cc

diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt
index 35c61376153..f66818b2e9d 100644
--- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt
@@ -17,6 +17,7 @@ cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps}
 cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
 cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
 cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
+cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
 
 lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
@@ -28,6 +29,7 @@ lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS relu_compute_x
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86 operator)
 lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
+lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
 
 
 set(x86_kernels
@@ -44,6 +46,7 @@ set(x86_kernels
     concat_compute_x86
     conv_compute_x86
     pool_compute_x86
+    batch_norm_compute_x86
     )
 
 set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
diff --git a/paddle/fluid/lite/kernels/x86/batch_norm_compute.cc b/paddle/fluid/lite/kernels/x86/batch_norm_compute.cc
new file mode 100644
index 00000000000..008d2398014
--- /dev/null
+++ b/paddle/fluid/lite/kernels/x86/batch_norm_compute.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/x86/batch_norm_compute.h"
+
+REGISTER_LITE_KERNEL(batch_norm, kX86, kFloat, kNCHW,
+                     paddle::lite::kernels::x86::BatchNormCompute<float>, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Mean", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Variance", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/paddle/fluid/lite/kernels/x86/batch_norm_compute.h b/paddle/fluid/lite/kernels/x86/batch_norm_compute.h
new file mode 100644
index 00000000000..e9cf55d208d
--- /dev/null
+++ b/paddle/fluid/lite/kernels/x86/batch_norm_compute.h
@@ -0,0 +1,158 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+using EigenArrayMap =
+    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using ConstEigenArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
+template <typename T>
+using ConstEigenVectorArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
+
+template <typename T>
+class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::BatchNormParam;
+  void Run() override {
+    auto &param = *param_.get_mutable<param_t>();
+    bool global_stats = param.is_test || param.use_global_stats;
+
+    const auto *x = param.x;
+    const auto &x_dims = x->dims();
+    CHECK(x_dims.size() >= 2 && x_dims.size() <= 5);
+    const int N = x_dims[0];
+    const int C = param.data_layout == DATALAYOUT(kNCHW)
+                      ? x_dims[1]
+                      : x_dims[x_dims.size() - 1];
+    const int sample_size = x->dims().production() / N / C;
+
+    // alloc memory
+    param.y->template mutable_data<T>();
+    param.mean_out->template mutable_data<T>();
+    param.variance_out->template mutable_data<T>();
+    param.saved_mean->template mutable_data<T>();
+    param.saved_variance->template mutable_data<T>();
+
+    if (!global_stats) {
+      // saved_mean and saved_variance are used only for this batch of data.
+      EigenVectorArrayMap<T> saved_mean_e(param.saved_mean->mutable_data<T>(),
+                                          C);
+      EigenVectorArrayMap<T> saved_variance_e(
+          param.saved_variance->mutable_data<T>(), C);
+      saved_mean_e.setZero();
+      saved_variance_e.setZero();
+
+      EigenVectorArrayMap<T> running_mean_arr(
+          param.mean_out->mutable_data<T>(), C);
+      EigenVectorArrayMap<T> running_var_arr(
+          param.variance_out->mutable_data<T>(), C);
+
+      if ((N * sample_size) == 1) {
+        LOG(WARNING) << "Only 1 element in normalization dimension, "
+                     << "we skip the batch norm calculation, let y = x.";
+        framework::TensorCopy(x->raw_tensor(), platform::CPUPlace(),
+                              &param.y->raw_tensor());
+        return;
+      }
+
+      switch (param.data_layout) {
+        case DATALAYOUT(kNCHW): {
+          ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
+          for (int nc = 0; nc < N * C; ++nc) {
+            saved_mean_e(nc % C) += x_arr.col(nc).sum();
+          }
+          saved_mean_e /= N * sample_size;
+          for (int nc = 0; nc < N * C; ++nc) {
+            saved_variance_e(nc % C) +=
+                (x_arr.col(nc) - saved_mean_e(nc % C)).matrix().squaredNorm();
+          }
+          saved_variance_e /= N * sample_size;
+          break;
+        }
+        default:
+          LOG(FATAL) << "Unknown storage order: "
+                     << DataLayoutToStr(param.data_layout);
+          break;
+      }
+      running_mean_arr = running_mean_arr * param.momentum +
+                         saved_mean_e * (1. - param.momentum);
+      running_var_arr = running_var_arr * param.momentum +
+                        saved_variance_e * (1. - param.momentum);
+    }
+
+    // use SavedMean and SavedVariance to do normalize
+    Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
+    if (global_stats) {
+      ConstEigenVectorArrayMap<T> var_arr(param.variance->data<T>(), C);
+      inv_std = (var_arr + param.epsilon).sqrt().inverse();
+    } else {
+      EigenVectorArrayMap<T> saved_inv_std(
+          param.saved_variance->mutable_data<T>(), C);
+      // inverse SavedVariance first, gradient will use it too.
+      saved_inv_std = (saved_inv_std + param.epsilon).inverse().sqrt();
+      inv_std = saved_inv_std;
+    }
+
+    ConstEigenVectorArrayMap<T> mean_arr(
+        global_stats ? param.mean->data<T>() : param.saved_mean->data<T>(), C);
+
+    // (x - est_mean) * inv_var * scale + bias
+    // is transformed to
+    // (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
+
+    ConstEigenVectorArrayMap<T> scale_arr(param.scale->data<T>(), C);
+    ConstEigenVectorArrayMap<T> bias_arr(param.bias->data<T>(), C);
+    Eigen::Array<T, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
+    Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
+        bias_arr - mean_arr * inv_std * scale_arr;
+
+    switch (param.data_layout) {
+      case DATALAYOUT(kNCHW): {
+        EigenArrayMap<T> y_arr(param.y->mutable_data<T>(), sample_size, N * C);
+        ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
+        for (int nc = 0; nc < N * C; ++nc) {
+          y_arr.col(nc) = x_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
+        }
+        break;
+      }
+      default:
+        LOG(FATAL) << "Unknown storage order: "
+                   << DataLayoutToStr(param.data_layout);
+        break;
+    }
+  }
+  virtual ~BatchNormCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/x86/batch_norm_compute_test.cc b/paddle/fluid/lite/kernels/x86/batch_norm_compute_test.cc
new file mode 100644
index 00000000000..d9c53035db1
--- /dev/null
+++ b/paddle/fluid/lite/kernels/x86/batch_norm_compute_test.cc
@@ -0,0 +1,141 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/x86/batch_norm_compute.h"
+#include <gtest/gtest.h>
+#include <iostream>
+#include <vector>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(batch_norm_x86, retrieve_op) {
+  auto batch_norm =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "batch_norm");
+  ASSERT_FALSE(batch_norm.empty());
+  ASSERT_TRUE(batch_norm.front());
+}
+
+TEST(batch_norm_x86, init) {
+  BatchNormCompute<float> batch_norm;
+  ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat));
+  ASSERT_EQ(batch_norm.target(), TARGET(kX86));
+}
+
+TEST(batch_norm_x86, run_test) {
+  lite::Tensor x, scale, bias, mean, variance, y, mean_out, variance_out,
+      saved_mean, saved_variance;
+  constexpr int batch_size = 2;
+  std::vector<int64_t> x_shape{batch_size, 3, 64, 64};
+  x.Resize(lite::DDim(x_shape));
+
+  std::vector<int64_t> scale_shape{3};
+  scale.Resize(lite::DDim(scale_shape));
+
+  std::vector<int64_t> bias_shape{3};
+  bias.Resize(lite::DDim(bias_shape));
+
+  std::vector<int64_t> mean_shape{3};
+  mean.Resize(lite::DDim(mean_shape));
+
+  std::vector<int64_t> variance_shape{3};
+  variance.Resize(lite::DDim(variance_shape));
+
+  std::vector<int64_t> y_shape{batch_size, 3, 64, 64};
+  y.Resize(lite::DDim(y_shape));
+
+  std::vector<int64_t> mean_out_shape{3};
+  mean_out.Resize(lite::DDim(mean_out_shape));
+
+  std::vector<int64_t> variance_out_shape{3};
+  variance_out.Resize(lite::DDim(variance_out_shape));
+
+  std::vector<int64_t> saved_mean_shape{3};
+  saved_mean.Resize(lite::DDim(saved_mean_shape));
+
+  std::vector<int64_t> saved_variance_shape{3};
+  saved_variance.Resize(lite::DDim(saved_variance_shape));
+
+  auto x_data = x.mutable_data<float>();
+  auto scale_data = scale.mutable_data<float>();
+  auto bias_data = bias.mutable_data<float>();
+  auto mean_data = mean.mutable_data<float>();
+  auto variance_data = variance.mutable_data<float>();
+  auto y_data = y.mutable_data<float>();
+  mean_out.mutable_data<float>();
+  variance_out.mutable_data<float>();
+  saved_mean.mutable_data<float>();
+  saved_variance.mutable_data<float>();
+
+  for (int64_t i = 0; i < x.dims().production(); i++) {
+    x_data[i] = static_cast<float>(i);
+  }
+  for (int i = 0; i < scale.dims().production(); i++) {
+    scale_data[i] = static_cast<float>(i) * 0.01f + 0.03f;
+  }
+  for (int i = 0; i < bias.dims().production(); i++) {
+    bias_data[i] = static_cast<float>(i) * 0.065f + 0.1f;
+  }
+  for (int i = 0; i < mean.dims().production(); i++) {
+    mean_data[i] = static_cast<float>(i) * 0.0565f;
+  }
+  for (int i = 0; i < variance.dims().production(); i++) {
+    variance_data[i] = static_cast<float>(i) * 2.08f + 1.5f;
+  }
+
+  BatchNormCompute<float> batch_norm;
+  operators::BatchNormParam param;
+
+  param.x = &x;
+  param.is_test = false;
+  param.scale = &scale;
+  param.bias = &bias;
+  param.mean = &mean;
+  param.variance = &variance;
+  param.use_global_stats = false;
+  param.epsilon = 1e-4f;
+  param.momentum = 0.9f;
+  param.y = &y;
+  param.mean_out = &mean_out;
+  param.variance_out = &variance_out;
+  param.saved_mean = &saved_mean;
+  param.saved_variance = &saved_variance;
+
+  batch_norm.SetParam(param);
+  batch_norm.Run();
+
+  LOG(INFO) << "output: " << y;
+  LOG(INFO) << "mean_out: " << mean_out;
+  LOG(INFO) << "variance_out: " << variance_out;
+  LOG(INFO) << "saved_mean: " << saved_mean;
+  LOG(INFO) << "saved_variance: " << saved_variance;
+
+  // print the first and last few output values as a quick sanity check
+  for (int64_t i = 0; i < y.dims().production(); i++) {
+    if (i < 5 || i > y.dims().production() - 5) {
+      LOG(INFO) << y_data[i];
+    }
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(batch_norm, kX86, kFloat, kNCHW, def);
-- 
GitLab
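
A note on the "formula transform" comment in batch_norm_compute.h above: instead of evaluating (x - mean) * inv_std * scale + bias element by element, the kernel folds the per-channel constants into new_scale and new_bias once, leaving a single multiply-add per element in the inner loop. Below is a minimal standalone sketch of that identity; it is plain C++ with no Paddle or Eigen dependency, and all values are made up for illustration (epsilon mirrors the test's param.epsilon):

#include <cassert>
#include <cmath>

int main() {
  // Arbitrary per-channel statistics; inv_std = 1 / sqrt(variance + epsilon).
  const float x = 3.5f, mean = 1.2f, variance = 4.0f, epsilon = 1e-4f;
  const float scale = 0.8f, bias = 0.25f;
  const float inv_std = 1.0f / std::sqrt(variance + epsilon);

  // Direct form: (x - mean) * inv_std * scale + bias.
  const float direct = (x - mean) * inv_std * scale + bias;

  // Folded form used by the kernel: precompute the per-channel constants,
  // then apply one multiply-add per element.
  const float new_scale = inv_std * scale;
  const float new_bias = bias - mean * inv_std * scale;
  const float folded = x * new_scale + new_bias;

  // The two forms agree up to float rounding.
  assert(std::fabs(direct - folded) < 1e-6f);
  return 0;
}

In the kernel the same folding is applied per channel: the NCHW tensor is viewed as a (sample_size, N * C) column array, and column nc uses the constants for channel nc % C.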