diff --git a/paddle/fluid/lite/arm/math/scale.cc b/paddle/fluid/lite/arm/math/scale.cc index 40b91e6979f6f330f96f4c086fe1856707d9b189..ce969358f689ef7713efb435ce58ba72471d282b 100644 --- a/paddle/fluid/lite/arm/math/scale.cc +++ b/paddle/fluid/lite/arm/math/scale.cc @@ -58,6 +58,111 @@ void scale(const float* din, float* dout, int num, float scale, } } +template <> +void scale(const float* din, float* dout, int outer_dim, int scale_dim, + int inner_dim, const float* scale_data, + const float* bias_data) { + int cnt = inner_dim >> 4; + int remain = inner_dim % 16; + int size = inner_dim * scale_dim; + for (int n = 0; n < outer_dim; n++) { + const float* din_ptr_n = din + n * size; + float* dout_ptr_n = dout + n * size; +#pragma omp parallel for + for (int i = 0; i < scale_dim; i++) { + const float* din_ptr = din_ptr_n + i * inner_dim; + float* dout_ptr = dout_ptr_n + i * inner_dim; + float scale = scale_data[i]; + float32x4_t vscale = vdupq_n_f32(scale); + float bias = bias_data[i]; + float32x4_t vbias = vdupq_n_f32(bias); + for (int j = 0; j < cnt; j++) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale); + float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale); + float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale); + float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale); + + din_ptr += 16; + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + + dout_ptr += 16; + } + for (int j = 0; j < remain; j++) { + *dout_ptr = *din_ptr * scale + bias; + dout_ptr++; + din_ptr++; + } + } + } +} + +template <> +void scale(const float* din, float* dout, int outer_dim, int scale_dim, + const float* scale_data, const float* bias_data) { + int cnt = scale_dim >> 4; + int remain = scale_dim % 16; + for (int n = 0; n < outer_dim; n++) { + const float* din_ptr_n = din + n * scale_dim; + float* dout_ptr_n = dout + n * scale_dim; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + int idx = i << 4; + const float* din_ptr = din_ptr_n + idx; + const float* scale_ptr = scale_data + idx; + const float* bias_ptr = bias_data + idx; + float* dout_ptr = dout_ptr_n + idx; + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t vscale0 = vld1q_f32(scale_ptr); + float32x4_t vbias0 = vld1q_f32(bias_ptr); + + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t vscale1 = vld1q_f32(scale_ptr + 4); + float32x4_t vbias1 = vld1q_f32(bias_ptr + 4); + + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t vscale2 = vld1q_f32(scale_ptr + 8); + float32x4_t vbias2 = vld1q_f32(bias_ptr + 8); + + float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0); + float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1); + + float32x4_t din3 = vld1q_f32(din_ptr + 12); + float32x4_t vscale3 = vld1q_f32(scale_ptr + 12); + float32x4_t vbias3 = vld1q_f32(bias_ptr + 12); + + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + + float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2); + float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3); + + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + } + int idx = cnt << 4; + const float* din_ptr = din_ptr_n + idx; + float* dout_ptr = dout_ptr_n + idx; + const float* scale_ptr = scale_data + idx; + const float* bias_ptr = bias_data + idx; + for (int j = 0; j < remain; j++) { + *dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr); + dout_ptr++; + din_ptr++; + scale_ptr++; + bias_ptr++; + } + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/paddle/fluid/lite/arm/math/scale.h b/paddle/fluid/lite/arm/math/scale.h index 97a5f79fc6bfabee5e38854e2ba89ce388648aac..2274dd23d2f4f486e39b97ad5040bde47af8a042 100644 --- a/paddle/fluid/lite/arm/math/scale.h +++ b/paddle/fluid/lite/arm/math/scale.h @@ -22,6 +22,14 @@ namespace math { template void scale(const T* din, T* dout, int num, float scale, float bias); +template +void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim, + const float* scale_data, const float* bias_data); + +template +void scale(const T* din, T* dout, int outer_dim, int scale_dim, + const float* scale_data, const float* bias_data); + } // namespace math } // namespace arm } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index 604473bb4bc512ed7656f5c7cf6a9fd2f3b82647..565c4a8a81de7990f58418f60d4c0e234fba5554 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -10,6 +10,7 @@ cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -18,6 +19,7 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) +lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) @@ -30,6 +32,7 @@ set(arm_kernels scale_compute_arm softmax_compute_arm conv_compute_arm + batch_norm_compute_arm elementwise_add_compute_arm pool_compute_arm split_compute_arm diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc b/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..0cb43dd5e0430092cb4e3edb13226ca30de61e4d --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void BatchNormCompute::PrepareForRun() { + auto& param = this->Param(); + auto x_dims = param.x->dims(); + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + int64_t channel_size = 0; + switch (param.data_layout) { + case DATALAYOUT(kNCHW): + channel_size = x_dims[1]; + break; + // case DATALAYOUT(kNHWC): + // channel_size = x_dims[x_dims.size() - 1]; + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(param.data_layout); + break; + } + new_scale.Resize({channel_size}); + new_bias.Resize({channel_size}); + auto* scale_data = param.scale->mutable_data(); + auto* bias_data = param.bias->mutable_data(); + auto* mean_data = param.mean->mutable_data(); + auto* variance_data = param.variance->mutable_data(); + auto* new_scale_data = new_scale.mutable_data(); + auto* new_bias_data = new_bias.mutable_data(); + for (int c = 0; c < channel_size; c++) { + float inv_scale = 1.f / (std::sqrt(variance_data[c] + param.epsilon)); + new_bias_data[c] = + bias_data[c] - inv_scale * scale_data[c] * mean_data[c]; + new_scale_data[c] = inv_scale * scale_data[c]; + } + } +} + +void BatchNormCompute::Run() { + auto& param = this->Param(); + auto x_dims = param.x->dims(); + auto x_data = param.x->mutable_data(); + auto y_data = param.y->mutable_data(); + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + auto* new_scale_data = new_scale.mutable_data(); + auto* new_bias_data = new_bias.mutable_data(); + int64_t outer_size = 0; + int64_t channel_size = 0; + int64_t inner_size = 0; + switch (param.data_layout) { + case DATALAYOUT(kNCHW): + outer_size = x_dims[0]; + channel_size = x_dims[1]; + inner_size = x_dims.Slice(2, x_dims.size()).production(); + lite::arm::math::scale(x_data, y_data, outer_size, channel_size, + inner_size, new_scale_data, new_bias_data); + break; + // case DATALAYOUT(kNHWC): + // outer_size = x_dims.Slice(0, x_dims.size() - 1).production(); + // channel_size = x_dims[x_dims.size() - 1]; + // lite::arm::math::scale(x_data, y_data, outer_size, channel_size, + // new_scale_data, new_bias_data); + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(param.data_layout); + break; + } + } else { + // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and + // saved_variance + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::BatchNormCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Mean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Variance", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute.h b/paddle/fluid/lite/kernels/arm/batch_norm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..cf3ad3accded0db9a95d0f0794c863b4f7b1cd8e --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class BatchNormCompute : public KernelLite { + public: + using param_t = operators::BatchNormParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~BatchNormCompute() = default; + + private: + Tensor new_scale; + Tensor new_bias; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc b/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ca1a0b599b3448fe2dbed08fb37ccc9dae3450c --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void batch_norm_compute_ref(const operators::BatchNormParam& param) { + DDim x_dims = param.x->dims(); + auto x_data = param.x->mutable_data(); + auto scale_data = param.scale->mutable_data(); + auto bias_data = param.bias->mutable_data(); + auto mean_data = param.mean->mutable_data(); + auto variance_data = param.variance->mutable_data(); + auto y_data = param.y->mutable_data(); + float epsilon = param.epsilon; + float momentum = param.momentum; + DataLayoutType data_layout = param.data_layout; + + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + int64_t outer_size = 0; + int64_t channel_size = 0; + int64_t inner_size = 0; + switch (data_layout) { + case DATALAYOUT(kNCHW): + outer_size = x_dims[0]; + channel_size = x_dims[1]; + inner_size = x_dims.Slice(2, x_dims.size()).production(); + break; + // case DATALAYOUT(kNHWC): + // outer_size = x_dims.Slice(0, x_dims.size() - 1).production(); + // channel_size = x_dims[x_dims.size() - 1]; + // inner_size = 1; + // break; + default: + LOG(FATAL) << "Unknown storage order: " << DataLayoutToStr(data_layout); + break; + } + auto x_ptr = x_data; + auto y_ptr = y_data; + for (int o = 0; o < outer_size; o++) { + for (int c = 0; c < channel_size; c++) { + for (int i = 0; i < inner_size; i++) { + dtype norm_x = + (*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon); + *y_ptr = norm_x * scale_data[c] + bias_data[c]; + x_ptr++; + y_ptr++; + } + } + } + } else { + // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and + // saved_variance + } +} + +TEST(batch_norm_arm, retrive_op) { + auto batch_norm = + KernelRegistry::Global().Create( + "batch_norm"); + ASSERT_FALSE(batch_norm.empty()); + ASSERT_TRUE(batch_norm.front()); +} + +TEST(batch_norm_arm, init) { + BatchNormCompute batch_norm; + ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat)); + ASSERT_EQ(batch_norm.target(), TARGET(kARM)); +} + +TEST(batch_norm_arm, compute) { + DeviceInfo::Init(); + for (auto n : {1, 2}) { + for (auto c : {6, 32 /*, 128*/}) { + for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) { + for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) { + for (auto is_test : {/*false, */ true}) { + for (auto use_global_stats : {false, true}) { + for (auto epsilon : {1e-4f, 1e-5f}) { + for (auto momentum : {0.9f, 0.99f}) { + for (auto data_layout : + {DATALAYOUT(kNCHW) /*, DATALAYOUT(kNHWC)*/}) { + Tensor x; + Tensor scale; + Tensor bias; + Tensor mean; + Tensor variance; + Tensor y; + Tensor mean_out; + Tensor variance_out; + Tensor saved_mean; + Tensor saved_variance; + Tensor y_ref; + Tensor mean_out_ref; + Tensor variance_out_ref; + Tensor saved_mean_ref; + Tensor saved_variance_ref; + // set the dims of input, output, ref output tensors + std::vector in_out_shape; + switch (data_layout) { + case DATALAYOUT(kNCHW): + in_out_shape = {n, c, h, w}; + break; + // case DATALAYOUT(kNHWC): + // in_out_shape = {n, h, w, c}; + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(data_layout); + break; + } + x.Resize(in_out_shape); + scale.Resize({c}); + bias.Resize({c}); + mean.Resize({c}); + variance.Resize({c}); + y.Resize(in_out_shape); + mean_out.Resize({c}); + variance_out.Resize({c}); + saved_mean.Resize({c}); + saved_variance.Resize({c}); + y_ref.Resize(in_out_shape); + mean_out_ref.Resize({c}); + variance_out_ref.Resize({c}); + saved_mean_ref.Resize({c}); + saved_variance_ref.Resize({c}); + // initialize the data of input tensors + auto* x_data = x.mutable_data(); + auto* scale_data = scale.mutable_data(); + auto* bias_data = bias.mutable_data(); + auto* mean_data = mean.mutable_data(); + auto* variance_data = variance.mutable_data(); + auto* y_data = y.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i % 64); + } + for (int i = 0; i < scale.dims().production(); i++) { + scale_data[i] = static_cast(i) * 0.01f + 0.03f; + } + for (int i = 0; i < bias.dims().production(); i++) { + bias_data[i] = static_cast(i) * 0.065f + 0.1f; + } + for (int i = 0; i < mean.dims().production(); i++) { + mean_data[i] = static_cast(i) * 0.0565f; + } + for (int i = 0; i < variance.dims().production(); i++) { + variance_data[i] = static_cast(i) * 2.08f + 1.5f; + } + // prepare kernel params and run + BatchNormCompute batch_norm; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + batch_norm.SetContext(std::move(ctx)); + operators::BatchNormParam param; + param.x = &x; + param.scale = &scale; + param.bias = &bias; + param.mean = &mean; + param.variance = &variance; + param.is_test = is_test; + param.use_global_stats = use_global_stats; + param.epsilon = epsilon; + param.momentum = momentum; + param.data_layout = data_layout; + param.y = &y; + param.mean_out = &mean_out; + param.variance_out = &variance_out; + param.saved_mean = &saved_mean; + param.saved_variance = &saved_variance; + batch_norm.SetParam(param); + batch_norm.Launch(); + // invoking ref implementation and compare results + param.y = &y_ref; + param.mean_out = &mean_out_ref; + param.variance_out = &variance_out_ref; + param.saved_mean = &saved_mean_ref; + param.saved_variance = &saved_variance_ref; + batch_norm_compute_ref(param); + auto* y_ref_data = y_ref.mutable_data(); + for (int i = 0; i < y.dims().production(); i++) { + EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 9a90666420e94bdb585feeac689d9227fc6a2104..e996c5a29b56a9884e87b8e89b388d3ae03ee560 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS}) cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS}) cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} ) +cc_library(batch_norm_op_lite SRCS batch_norm_op.cc DEPS ${op_DEPS}) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS}) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) @@ -30,6 +31,7 @@ set(ops_lite scale_op_lite softmax_op_lite reshape_op_lite + batch_norm_op_lite feed_op_lite fetch_op_lite io_copy_op_lite @@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) +lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite) lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/batch_norm_op.cc b/paddle/fluid/lite/operators/batch_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e974d0134dad93a2241c265687a190b10d5ff85d --- /dev/null +++ b/paddle/fluid/lite/operators/batch_norm_op.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/batch_norm_op.h" +#include "paddle/fluid/lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool BatchNormOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.bias); + CHECK_OR_FALSE(param_.scale); + CHECK_OR_FALSE(param_.mean); + CHECK_OR_FALSE(param_.variance); + CHECK_OR_FALSE(param_.y); + if (!param_.is_test) { + CHECK_OR_FALSE(param_.mean_out); + CHECK_OR_FALSE(param_.variance_out); + CHECK_OR_FALSE(param_.saved_mean); + CHECK_OR_FALSE(param_.saved_variance); + } + auto x_dims = param_.x->dims(); + auto scale_dims = param_.scale->dims(); + auto bias_dims = param_.bias->dims(); + auto mean_dims = param_.mean->dims(); + auto variance_dims = param_.variance->dims(); + CHECK(x_dims.size() >= 2 && x_dims.size() <= 5) + << "Input X must have 2 to 5 dimensions."; + CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions."; + CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions."; + CHECK_EQ(mean_dims.size(), 1UL) << "Input Mean must have 1 dimensions."; + CHECK_EQ(variance_dims.size(), 1UL) + << "Input Variance must have 1 dimensions."; + return true; +} + +bool BatchNormOp::InferShape() const { + auto x_dims = param_.x->dims(); + int64_t channel_size = 0; + switch (param_.data_layout) { + case DATALAYOUT(kNCHW): + channel_size = x_dims[1]; + break; + // case DATALAYOUT(kNHWC): + // channel_size = x_dims[x_dims.size() - 1]; + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(param_.data_layout); + break; + } + if (!param_.is_test) { + param_.mean_out->Resize({channel_size}); + param_.variance_out->Resize({channel_size}); + param_.saved_mean->Resize({channel_size}); + param_.saved_variance->Resize({channel_size}); + } + param_.y->Resize(x_dims); + return true; +} + +bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable(); + param_.bias = + scope->FindVar(op_desc.Input("Bias").front())->GetMutable(); + param_.scale = + scope->FindVar(op_desc.Input("Scale").front())->GetMutable(); + param_.mean = + scope->FindVar(op_desc.Input("Mean").front())->GetMutable(); + param_.variance = + scope->FindVar(op_desc.Input("Variance").front())->GetMutable(); + param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable(); + param_.is_test = op_desc.GetAttr("is_test"); + param_.use_global_stats = op_desc.GetAttr("use_global_stats"); + if (!param_.is_test) { + param_.mean_out = + scope->FindVar(op_desc.Output("MeanOut").front())->GetMutable(); + param_.variance_out = scope->FindVar(op_desc.Output("VarianceOut").front()) + ->GetMutable(); + param_.saved_mean = scope->FindVar(op_desc.Output("SavedMean").front()) + ->GetMutable(); + param_.saved_variance = + scope->FindVar(op_desc.Output("SavedVariance").front()) + ->GetMutable(); + } + param_.epsilon = op_desc.GetAttr("epsilon"); + param_.momentum = op_desc.GetAttr("momentum"); + std::string data_layout = op_desc.GetAttr("data_layout"); + CHECK_EQ(data_layout, "NCHW") << "TODO(hong19860320): Only support NCHW."; + // param_.data_layout = StringToDataLayout(data_layout); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOp); diff --git a/paddle/fluid/lite/operators/batch_norm_op.h b/paddle/fluid/lite/operators/batch_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..30e8747319b1575b0c63e4b2919ed1363ad10bef --- /dev/null +++ b/paddle/fluid/lite/operators/batch_norm_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class BatchNormOp : public OpLite { + public: + BatchNormOp() {} + explicit BatchNormOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "batch_norm"; } + + private: + mutable BatchNormParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/batch_norm_op_test.cc b/paddle/fluid/lite/operators/batch_norm_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5449e7ec19f1bfdcc3f11a7fdac65dcfbc9af17 --- /dev/null +++ b/paddle/fluid/lite/operators/batch_norm_op_test.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/batch_norm_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(batch_norm_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* scale = scope.Var("scale")->GetMutable(); + auto* bias = scope.Var("bias")->GetMutable(); + auto* mean = scope.Var("mean")->GetMutable(); + auto* variance = scope.Var("variance")->GetMutable(); + auto* y = scope.Var("y")->GetMutable(); + x->Resize({2, 32, 10, 20}); + auto x_dims = x->dims(); + const int64_t channel_size = x_dims[1]; // NCHW + scale->Resize({channel_size}); + bias->Resize({channel_size}); + mean->Resize({channel_size}); + variance->Resize(DDim({channel_size})); + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("batch_norm"); + desc.SetInput("X", {"x"}); + desc.SetInput("Scale", {"scale"}); + desc.SetInput("Bias", {"bias"}); + desc.SetInput("Mean", {"mean"}); + desc.SetInput("Variance", {"variance"}); + desc.SetOutput("Y", {"y"}); + desc.SetAttr("is_test", true); + desc.SetAttr("use_global_stats", false); + desc.SetAttr("epsilon", 1e-5f); + desc.SetAttr("momentum", 0.9f); + desc.SetAttr("data_layout", std::string("NCHW")); + + BatchNormOp batch_norm("batch_norm"); + + batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + batch_norm.Attach(desc, &scope); + batch_norm.CheckShape(); + batch_norm.InferShape(); + + // check output dims + auto y_dims = y->dims(); + CHECK_EQ(y_dims.size(), x_dims.size()); + for (size_t i = 0; i < y_dims.size(); i++) { + CHECK_EQ(y_dims[i], x_dims[i]); + } +} + +TEST(batch_norm_op_lite, test_enable_is_test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* scale = scope.Var("scale")->GetMutable(); + auto* bias = scope.Var("bias")->GetMutable(); + auto* mean = scope.Var("mean")->GetMutable(); + auto* variance = scope.Var("variance")->GetMutable(); + auto* y = scope.Var("y")->GetMutable(); + auto* mean_out = scope.Var("mean_out")->GetMutable(); + auto* variance_out = scope.Var("variance_out")->GetMutable(); + auto* saved_mean = scope.Var("saved_mean")->GetMutable(); + auto* saved_variance = scope.Var("saved_variance")->GetMutable(); + x->Resize({2, 32, 10, 20}); + auto x_dims = x->dims(); + const int64_t channel_size = x_dims[1]; // NCHW + scale->Resize({channel_size}); + bias->Resize({channel_size}); + mean->Resize({channel_size}); + variance->Resize({channel_size}); + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("batch_norm"); + desc.SetInput("X", {"x"}); + desc.SetInput("Scale", {"scale"}); + desc.SetInput("Bias", {"bias"}); + desc.SetInput("Mean", {"mean"}); + desc.SetInput("Variance", {"variance"}); + desc.SetOutput("Y", {"y"}); + desc.SetOutput("MeanOut", {"mean_out"}); + desc.SetOutput("VarianceOut", {"variance_out"}); + desc.SetOutput("SavedMean", {"saved_mean"}); + desc.SetOutput("SavedVariance", {"saved_variance"}); + desc.SetAttr("is_test", false); + desc.SetAttr("use_global_stats", false); + desc.SetAttr("epsilon", 1e-5f); + desc.SetAttr("momentum", 0.9f); + desc.SetAttr("data_layout", std::string("NCHW")); + + BatchNormOp batch_norm("batch_norm"); + + batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + batch_norm.Attach(desc, &scope); + batch_norm.CheckShape(); + batch_norm.InferShape(); + + // check output dims + auto y_dims = y->dims(); + CHECK_EQ(y_dims.size(), x_dims.size()); + for (size_t i = 0; i < y_dims.size(); i++) { + CHECK_EQ(y_dims[i], x_dims[i]); + } + auto mean_out_dims = mean_out->dims(); + auto variance_out_dims = variance_out->dims(); + auto saved_mean_dims = saved_mean->dims(); + auto saved_variance_dims = saved_variance->dims(); + CHECK_EQ(mean_out_dims.size(), 1UL); + CHECK_EQ(variance_out_dims.size(), 1UL); + CHECK_EQ(saved_mean_dims.size(), 1UL); + CHECK_EQ(saved_variance_dims.size(), 1UL); + CHECK_EQ(mean_out_dims[0], channel_size); + CHECK_EQ(variance_out_dims[0], channel_size); + CHECK_EQ(saved_mean_dims[0], channel_size); + CHECK_EQ(saved_variance_dims[0], channel_size); +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index a81e590b5a34db70c0b90759b4bd18b7d8d27cad..91a6067959854f608e31a6151a4e63e26df7eb64 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -146,6 +146,25 @@ struct ConvParam { std::string data_format{"Anylayout"}; }; +// For BatchNorm op +struct BatchNormParam { + lite::Tensor* x{}; + lite::Tensor* bias{}; + lite::Tensor* scale{}; + lite::Tensor* mean{}; + lite::Tensor* variance{}; + lite::Tensor* y{}; + lite::Tensor* mean_out{}; + lite::Tensor* variance_out{}; + lite::Tensor* saved_mean{}; + lite::Tensor* saved_variance{}; + bool is_test{true}; + bool use_global_stats{false}; + float epsilon; + float momentum; + DataLayoutType data_layout{DATALAYOUT(kNCHW)}; +}; + // For Pooling op struct PoolParam { lite::Tensor* x{};