diff --git a/paddle/fluid/lite/arm/math/scale.cc b/paddle/fluid/lite/arm/math/scale.cc
index 40b91e6979f6f330f96f4c086fe1856707d9b189..ce969358f689ef7713efb435ce58ba72471d282b 100644
--- a/paddle/fluid/lite/arm/math/scale.cc
+++ b/paddle/fluid/lite/arm/math/scale.cc
@@ -58,6 +58,111 @@ void scale(const float* din, float* dout, int num, float scale,
   }
 }
 
+template <>
+void scale(const float* din, float* dout, int outer_dim, int scale_dim,
+           int inner_dim, const float* scale_data,
+           const float* bias_data) {
+  int cnt = inner_dim >> 4;
+  int remain = inner_dim % 16;
+  int size = inner_dim * scale_dim;
+  for (int n = 0; n < outer_dim; n++) {
+    const float* din_ptr_n = din + n * size;
+    float* dout_ptr_n = dout + n * size;
+#pragma omp parallel for
+    for (int i = 0; i < scale_dim; i++) {
+      const float* din_ptr = din_ptr_n + i * inner_dim;
+      float* dout_ptr = dout_ptr_n + i * inner_dim;
+      float scale = scale_data[i];
+      float32x4_t vscale = vdupq_n_f32(scale);
+      float bias = bias_data[i];
+      float32x4_t vbias = vdupq_n_f32(bias);
+      for (int j = 0; j < cnt; j++) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        float32x4_t din2 = vld1q_f32(din_ptr + 8);
+        float32x4_t din3 = vld1q_f32(din_ptr + 12);
+
+        float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
+        float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
+        float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
+        float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
+
+        din_ptr += 16;
+        vst1q_f32(dout_ptr, vsum1);
+        vst1q_f32(dout_ptr + 4, vsum2);
+        vst1q_f32(dout_ptr + 8, vsum3);
+        vst1q_f32(dout_ptr + 12, vsum4);
+
+        dout_ptr += 16;
+      }
+      for (int j = 0; j < remain; j++) {
+        *dout_ptr = *din_ptr * scale + bias;
+        dout_ptr++;
+        din_ptr++;
+      }
+    }
+  }
+}
+
+template <>
+void scale(const float* din, float* dout, int outer_dim, int scale_dim,
+           const float* scale_data, const float* bias_data) {
+  int cnt = scale_dim >> 4;
+  int remain = scale_dim % 16;
+  for (int n = 0; n < outer_dim; n++) {
+    const float* din_ptr_n = din + n * scale_dim;
+    float* dout_ptr_n = dout + n * scale_dim;
+#pragma omp parallel for
+    for (int i = 0; i < cnt; i++) {
+      int idx = i << 4;
+      const float* din_ptr = din_ptr_n + idx;
+      const float* scale_ptr = scale_data + idx;
+      const float* bias_ptr = bias_data + idx;
+      float* dout_ptr = dout_ptr_n + idx;
+
+      float32x4_t din0 = vld1q_f32(din_ptr);
+      float32x4_t vscale0 = vld1q_f32(scale_ptr);
+      float32x4_t vbias0 = vld1q_f32(bias_ptr);
+
+      float32x4_t din1 = vld1q_f32(din_ptr + 4);
+      float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
+      float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
+
+      float32x4_t din2 = vld1q_f32(din_ptr + 8);
+      float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
+      float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
+
+      float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
+      float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
+
+      float32x4_t din3 = vld1q_f32(din_ptr + 12);
+      float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
+      float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
+
+      vst1q_f32(dout_ptr, vsum1);
+      vst1q_f32(dout_ptr + 4, vsum2);
+
+      float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
+      float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
+
+      vst1q_f32(dout_ptr + 8, vsum3);
+      vst1q_f32(dout_ptr + 12, vsum4);
+    }
+    int idx = cnt << 4;
+    const float* din_ptr = din_ptr_n + idx;
+    float* dout_ptr = dout_ptr_n + idx;
+    const float* scale_ptr = scale_data + idx;
+    const float* bias_ptr = bias_data + idx;
+    for (int j = 0; j < remain; j++) {
+      *dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
+      dout_ptr++;
+      din_ptr++;
+      scale_ptr++;
+      bias_ptr++;
+    }
+  }
+}
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/paddle/fluid/lite/arm/math/scale.h b/paddle/fluid/lite/arm/math/scale.h
index 97a5f79fc6bfabee5e38854e2ba89ce388648aac..2274dd23d2f4f486e39b97ad5040bde47af8a042 100644
--- a/paddle/fluid/lite/arm/math/scale.h
+++ b/paddle/fluid/lite/arm/math/scale.h
@@ -22,6 +22,14 @@ namespace math {
 template <typename T>
 void scale(const T* din, T* dout, int num, float scale, float bias);
 
+template <typename T>
+void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
+           const float* scale_data, const float* bias_data);
+
+template <typename T>
+void scale(const T* din, T* dout, int outer_dim, int scale_dim,
+           const float* scale_data, const float* bias_data);
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
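A plain scalar rendering of the two new `scale` overloads may help when reading the NEON intrinsics above. This is a minimal sketch with illustrative names (`scale_ref`, `scale_elem_ref` are not part of the patch): the first overload broadcasts one scale/bias pair per channel across `inner_dim` elements, the second applies per-element scale/bias vectors along the innermost axis.

```cpp
#include <cstdio>
#include <vector>

// Per-channel variant over [outer_dim, scale_dim, inner_dim]: every element
// of channel i is multiplied by scale_data[i] and offset by bias_data[i].
void scale_ref(const float* din, float* dout, int outer_dim, int scale_dim,
               int inner_dim, const float* scale_data, const float* bias_data) {
  for (int n = 0; n < outer_dim; n++) {
    for (int i = 0; i < scale_dim; i++) {
      const float* x = din + (n * scale_dim + i) * inner_dim;
      float* y = dout + (n * scale_dim + i) * inner_dim;
      for (int j = 0; j < inner_dim; j++) {
        y[j] = x[j] * scale_data[i] + bias_data[i];
      }
    }
  }
}

// Element-wise variant over [outer_dim, scale_dim]: scale/bias vary along
// the innermost axis instead of being broadcast over it.
void scale_elem_ref(const float* din, float* dout, int outer_dim,
                    int scale_dim, const float* scale_data,
                    const float* bias_data) {
  for (int n = 0; n < outer_dim; n++) {
    const float* x = din + n * scale_dim;
    float* y = dout + n * scale_dim;
    for (int i = 0; i < scale_dim; i++) {
      y[i] = x[i] * scale_data[i] + bias_data[i];
    }
  }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6}, y(6);
  float s[] = {2.f, 0.5f}, b[] = {1.f, -1.f};
  scale_ref(x.data(), y.data(), /*outer_dim=*/1, /*scale_dim=*/2,
            /*inner_dim=*/3, s, b);
  for (float v : y) printf("%g ", v);  // 3 5 7 1 1.5 2
  printf("\n");
}
```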
diff --git a/paddle/fluid/lite/arm/math/split.cc b/paddle/fluid/lite/arm/math/split.cc
index 6dd6de6242e806947dfc630fd8f2a4dd03c89335..bf8d50590ff89c451347e33a289391b8d929e5b6 100644
--- a/paddle/fluid/lite/arm/math/split.cc
+++ b/paddle/fluid/lite/arm/math/split.cc
@@ -52,10 +52,10 @@ void split_cpy(const float* din, float* dout, int num) {
 }
 
 template <>
-void split(const float* din, std::vector<lite::Tensor*>* dout,
+void split(const float* din, const std::vector<lite::Tensor*>& dout,
            const int axis, const std::vector<int64_t>& in_strides) {
   int input_offset = 0;
-  for (auto out : *dout) {
+  for (auto out : dout) {
     auto out_dim = out->dims();
     std::vector<int64_t> out_strides(out_dim.size());
     out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
diff --git a/paddle/fluid/lite/arm/math/split.h b/paddle/fluid/lite/arm/math/split.h
index 9b5651d81ffa75362fcc39db82157c56548917c0..643214e174c3ede02f430ee4ded7cee097ba0afc 100644
--- a/paddle/fluid/lite/arm/math/split.h
+++ b/paddle/fluid/lite/arm/math/split.h
@@ -26,7 +26,7 @@ template <typename T>
 void split_cpy(const T* din, T* dout, int num);
 
 template <typename T>
-void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
+void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
            const std::vector<int64_t>& in_strides);
 
 }  // namespace math
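The split change above turns the output list from a mutated `std::vector<lite::Tensor*>*` into a caller-owned `const std::vector<lite::Tensor*>&`. As a reading aid, here is a self-contained sketch of the strided copy that `math::split` performs on each output; `split_ref` is an illustrative name, the lite::Tensor plumbing is replaced with raw buffers, and strides follow the suffix-product convention visible in the diff (the stride of the last axis equals its extent).

```cpp
#include <cstdio>
#include <cstring>
#include <vector>

// in_strides[k] / out_strides[k] are suffix products of dims k..rank-1.
void split_ref(const float* din, const std::vector<int>& in_strides,
               const std::vector<std::vector<int>>& out_dims,
               const std::vector<float*>& dout, int axis) {
  int input_offset = 0;
  for (size_t t = 0; t < dout.size(); t++) {
    const std::vector<int>& dim = out_dims[t];
    std::vector<int> out_strides(dim.size());
    out_strides[dim.size() - 1] = dim[dim.size() - 1];
    for (int i = static_cast<int>(dim.size()) - 2; i >= 0; i--) {
      out_strides[i] = out_strides[i + 1] * dim[i];
    }
    int before = out_strides[0] / out_strides[axis];  // rows to copy
    int after = out_strides[axis];                    // contiguous row length
    for (int i = 0; i < before; i++) {
      std::memcpy(dout[t] + i * after,
                  din + input_offset + i * in_strides[axis],
                  sizeof(float) * after);
    }
    input_offset += out_strides[axis];
  }
}

int main() {
  // Split a 2x4 matrix along axis 1 into two 2x2 halves.
  float in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float a[4], b[4];
  split_ref(in, /*in_strides=*/{8, 4}, {{2, 2}, {2, 2}}, {a, b}, /*axis=*/1);
  printf("%g %g %g %g | %g %g %g %g\n", a[0], a[1], a[2], a[3], b[0], b[1],
         b[2], b[3]);  // 0 1 4 5 | 2 3 6 7
}
```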
diff --git a/paddle/fluid/lite/core/cpu_info.cc b/paddle/fluid/lite/core/cpu_info.cc
index df80f1c857688fd6fb76350e720effef0f3c15f6..ab1968295813006d5d11fc4fbf416b4f9c3a3215 100644
--- a/paddle/fluid/lite/core/cpu_info.cc
+++ b/paddle/fluid/lite/core/cpu_info.cc
@@ -54,15 +54,15 @@ void DeviceInfo::InitInternal(DeviceInfo* dev) {
               << ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]]
               << ", CPU ARCH: A" << dev->archs_[i];
   }
-  LOG(INFO) << "L1 DataCache size is: ";
+  VLOG(1) << "L1 DataCache size is: ";
   for (int i = 0; i < dev->compute_core_num_; ++i) {
-    LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB";
+    VLOG(1) << dev->L1_cache_[i] / 1024 << " KB";
   }
-  LOG(INFO) << "L2 Cache size is: ";
+  VLOG(1) << "L2 Cache size is: ";
   for (int i = 0; i < dev->compute_core_num_; ++i) {
-    LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB";
+    VLOG(1) << dev->L2_cache_[i] / 1024 << " KB";
   }
-  LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB";
+  VLOG(1) << "Total memory: " << dev->max_memory_ << "KB";
 
   dev->max_freq_ = max_freq[0];
   for (int j = 1; j < dev->compute_core_num_; ++j) {
diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
index 1cf66b0d266b3edf0b0d271ceb5e375f01f652c3..565c4a8a81de7990f58418f60d4c0e234fba5554 100644
--- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
@@ -6,10 +6,11 @@ message(STATUS "compile with lite ARM kernels")
 
 cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
-cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
+cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
+cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -18,8 +19,10 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat
 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
 lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
 lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
+lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
 lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
 lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
+lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
 lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
 
 set(arm_kernels
@@ -29,6 +32,7 @@ set(arm_kernels
     scale_compute_arm
     softmax_compute_arm
     conv_compute_arm
+    batch_norm_compute_arm
     elementwise_add_compute_arm
     pool_compute_arm
     split_compute_arm
diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc b/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0cb43dd5e0430092cb4e3edb13226ca30de61e4d
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute.cc
@@ -0,0 +1,114 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
+#include "paddle/fluid/lite/arm/math/funcs.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/type_system.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void BatchNormCompute::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+  auto x_dims = param.x->dims();
+  bool global_stats = param.is_test || param.use_global_stats;
+  if (global_stats) {
+    int64_t channel_size = 0;
+    switch (param.data_layout) {
+      case DATALAYOUT(kNCHW):
+        channel_size = x_dims[1];
+        break;
+      // case DATALAYOUT(kNHWC):
+      //   channel_size = x_dims[x_dims.size() - 1];
+      //   break;
+      default:
+        LOG(FATAL) << "Unknown storage order: "
+                   << DataLayoutToStr(param.data_layout);
+        break;
+    }
+    new_scale.Resize({channel_size});
+    new_bias.Resize({channel_size});
+    auto* scale_data = param.scale->mutable_data<float>();
+    auto* bias_data = param.bias->mutable_data<float>();
+    auto* mean_data = param.mean->mutable_data<float>();
+    auto* variance_data = param.variance->mutable_data<float>();
+    auto* new_scale_data = new_scale.mutable_data<float>();
+    auto* new_bias_data = new_bias.mutable_data<float>();
+    for (int c = 0; c < channel_size; c++) {
+      float inv_scale = 1.f / (std::sqrt(variance_data[c] + param.epsilon));
+      new_bias_data[c] =
+          bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
+      new_scale_data[c] = inv_scale * scale_data[c];
+    }
+  }
+}
+
+void BatchNormCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto x_dims = param.x->dims();
+  auto x_data = param.x->mutable_data<float>();
+  auto y_data = param.y->mutable_data<float>();
+  bool global_stats = param.is_test || param.use_global_stats;
+  if (global_stats) {
+    auto* new_scale_data = new_scale.mutable_data<float>();
+    auto* new_bias_data = new_bias.mutable_data<float>();
+    int64_t outer_size = 0;
+    int64_t channel_size = 0;
+    int64_t inner_size = 0;
+    switch (param.data_layout) {
+      case DATALAYOUT(kNCHW):
+        outer_size = x_dims[0];
+        channel_size = x_dims[1];
+        inner_size = x_dims.Slice(2, x_dims.size()).production();
+        lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
+                               inner_size, new_scale_data, new_bias_data);
+        break;
+      // case DATALAYOUT(kNHWC):
+      //   outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
+      //   channel_size = x_dims[x_dims.size() - 1];
+      //   lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
+      //                          new_scale_data, new_bias_data);
+      //   break;
+      default:
+        LOG(FATAL) << "Unknown storage order: "
+                   << DataLayoutToStr(param.data_layout);
+        break;
+    }
+  } else {
+    // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
+    // saved_variance
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW,
+                     paddle::lite::kernels::arm::BatchNormCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Mean", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Variance", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
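`PrepareForRun` folds the batch-norm statistics into a plain per-channel scale/bias so that `Run` can dispatch to the `scale` kernels added earlier. A standalone check of that algebra (illustrative values, not part of the patch):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // One channel's batch-norm parameters.
  float scale = 0.8f, bias = 0.2f, mean = 1.5f, variance = 4.0f, eps = 1e-5f;
  // Folding done in PrepareForRun:
  //   y = (x - mean) / sqrt(var + eps) * scale + bias
  // becomes
  //   y = x * new_scale + new_bias.
  float inv_std = 1.f / std::sqrt(variance + eps);
  float new_scale = inv_std * scale;
  float new_bias = bias - inv_std * scale * mean;
  for (float x : {-2.f, 0.f, 3.f}) {
    float direct = (x - mean) / std::sqrt(variance + eps) * scale + bias;
    float folded = x * new_scale + new_bias;
    printf("x=%g direct=%g folded=%g\n", x, direct, folded);  // identical
  }
}
```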
diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute.h b/paddle/fluid/lite/kernels/arm/batch_norm_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf3ad3accded0db9a95d0f0794c863b4f7b1cd8e
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class BatchNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::BatchNormParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+  virtual ~BatchNormCompute() = default;
+
+ private:
+  Tensor new_scale;
+  Tensor new_bias;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc b/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3ca1a0b599b3448fe2dbed08fb37ccc9dae3450c
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc
@@ -0,0 +1,221 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
+#include <gtest/gtest.h>
+#include <cmath>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+template <typename dtype>
+void batch_norm_compute_ref(const operators::BatchNormParam& param) {
+  DDim x_dims = param.x->dims();
+  auto x_data = param.x->mutable_data<dtype>();
+  auto scale_data = param.scale->mutable_data<dtype>();
+  auto bias_data = param.bias->mutable_data<dtype>();
+  auto mean_data = param.mean->mutable_data<dtype>();
+  auto variance_data = param.variance->mutable_data<dtype>();
+  auto y_data = param.y->mutable_data<dtype>();
+  float epsilon = param.epsilon;
+  float momentum = param.momentum;
+  DataLayoutType data_layout = param.data_layout;
+
+  bool global_stats = param.is_test || param.use_global_stats;
+  if (global_stats) {
+    int64_t outer_size = 0;
+    int64_t channel_size = 0;
+    int64_t inner_size = 0;
+    switch (data_layout) {
+      case DATALAYOUT(kNCHW):
+        outer_size = x_dims[0];
+        channel_size = x_dims[1];
+        inner_size = x_dims.Slice(2, x_dims.size()).production();
+        break;
+      // case DATALAYOUT(kNHWC):
+      //   outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
+      //   channel_size = x_dims[x_dims.size() - 1];
+      //   inner_size = 1;
+      //   break;
+      default:
+        LOG(FATAL) << "Unknown storage order: " << DataLayoutToStr(data_layout);
+        break;
+    }
+    auto x_ptr = x_data;
+    auto y_ptr = y_data;
+    for (int o = 0; o < outer_size; o++) {
+      for (int c = 0; c < channel_size; c++) {
+        for (int i = 0; i < inner_size; i++) {
+          dtype norm_x =
+              (*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
+          *y_ptr = norm_x * scale_data[c] + bias_data[c];
+          x_ptr++;
+          y_ptr++;
+        }
+      }
+    }
+  } else {
+    // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
+    // saved_variance
+  }
+}
+
+TEST(batch_norm_arm, retrive_op) {
+  auto batch_norm =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+          "batch_norm");
+  ASSERT_FALSE(batch_norm.empty());
+  ASSERT_TRUE(batch_norm.front());
+}
+
+TEST(batch_norm_arm, init) {
+  BatchNormCompute batch_norm;
+  ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat));
+  ASSERT_EQ(batch_norm.target(), TARGET(kARM));
+}
+
+TEST(batch_norm_arm, compute) {
+  DeviceInfo::Init();
+  for (auto n : {1, 2}) {
+    for (auto c : {6, 32 /*, 128*/}) {
+      for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) {
+        for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) {
+          for (auto is_test : {/*false, */ true}) {
+            for (auto use_global_stats : {false, true}) {
+              for (auto epsilon : {1e-4f, 1e-5f}) {
+                for (auto momentum : {0.9f, 0.99f}) {
+                  for (auto data_layout :
+                       {DATALAYOUT(kNCHW) /*, DATALAYOUT(kNHWC)*/}) {
+                    Tensor x;
+                    Tensor scale;
+                    Tensor bias;
+                    Tensor mean;
+                    Tensor variance;
+                    Tensor y;
+                    Tensor mean_out;
+                    Tensor variance_out;
+                    Tensor saved_mean;
+                    Tensor saved_variance;
+                    Tensor y_ref;
+                    Tensor mean_out_ref;
+                    Tensor variance_out_ref;
+                    Tensor saved_mean_ref;
+                    Tensor saved_variance_ref;
+                    // set the dims of input, output, ref output tensors
+                    std::vector<int64_t> in_out_shape;
+                    switch (data_layout) {
+                      case DATALAYOUT(kNCHW):
+                        in_out_shape = {n, c, h, w};
+                        break;
+                      // case DATALAYOUT(kNHWC):
+                      //   in_out_shape = {n, h, w, c};
+                      //   break;
+                      default:
+                        LOG(FATAL) << "Unknown storage order: "
+                                   << DataLayoutToStr(data_layout);
+                        break;
+                    }
+                    x.Resize(in_out_shape);
+                    scale.Resize({c});
+                    bias.Resize({c});
+                    mean.Resize({c});
+                    variance.Resize({c});
+                    y.Resize(in_out_shape);
+                    mean_out.Resize({c});
+                    variance_out.Resize({c});
+                    saved_mean.Resize({c});
+                    saved_variance.Resize({c});
+                    y_ref.Resize(in_out_shape);
+                    mean_out_ref.Resize({c});
+                    variance_out_ref.Resize({c});
+                    saved_mean_ref.Resize({c});
+                    saved_variance_ref.Resize({c});
+                    // initialize the data of input tensors
+                    auto* x_data = x.mutable_data<float>();
+                    auto* scale_data = scale.mutable_data<float>();
+                    auto* bias_data = bias.mutable_data<float>();
+                    auto* mean_data = mean.mutable_data<float>();
+                    auto* variance_data = variance.mutable_data<float>();
+                    auto* y_data = y.mutable_data<float>();
+                    for (int i = 0; i < x.dims().production(); i++) {
+                      x_data[i] = static_cast<float>(i % 64);
+                    }
+                    for (int i = 0; i < scale.dims().production(); i++) {
+                      scale_data[i] = static_cast<float>(i) * 0.01f + 0.03f;
+                    }
+                    for (int i = 0; i < bias.dims().production(); i++) {
+                      bias_data[i] = static_cast<float>(i) * 0.065f + 0.1f;
+                    }
+                    for (int i = 0; i < mean.dims().production(); i++) {
+                      mean_data[i] = static_cast<float>(i) * 0.0565f;
+                    }
+                    for (int i = 0; i < variance.dims().production(); i++) {
+                      variance_data[i] = static_cast<float>(i) * 2.08f + 1.5f;
+                    }
+                    // prepare kernel params and run
+                    BatchNormCompute batch_norm;
+                    std::unique_ptr<KernelContext> ctx(new KernelContext);
+                    ctx->As<ARMContext>();
+                    batch_norm.SetContext(std::move(ctx));
+                    operators::BatchNormParam param;
+                    param.x = &x;
+                    param.scale = &scale;
+                    param.bias = &bias;
+                    param.mean = &mean;
+                    param.variance = &variance;
+                    param.is_test = is_test;
+                    param.use_global_stats = use_global_stats;
+                    param.epsilon = epsilon;
+                    param.momentum = momentum;
+                    param.data_layout = data_layout;
+                    param.y = &y;
+                    param.mean_out = &mean_out;
+                    param.variance_out = &variance_out;
+                    param.saved_mean = &saved_mean;
+                    param.saved_variance = &saved_variance;
+                    batch_norm.SetParam(param);
+                    batch_norm.Launch();
+                    // invoking ref implementation and compare results
+                    param.y = &y_ref;
+                    param.mean_out = &mean_out_ref;
+                    param.variance_out = &variance_out_ref;
+                    param.saved_mean = &saved_mean_ref;
+                    param.saved_variance = &saved_variance_ref;
+                    batch_norm_compute_ref<float>(param);
+                    auto* y_ref_data = y_ref.mutable_data<float>();
+                    for (int i = 0; i < y.dims().production(); i++) {
+                      EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5);
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
index 7950788d3009f2398b2f1e95c3cbef3f6398a084..e4d80265d7728fa0eeea97fd070a982a8888ec7e 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
@@ -124,7 +124,20 @@ TEST(conv_arm, init) {
 
 TEST(conv_arm, compute) {
   DeviceInfo::Init();
-#if 0
+#if 1
+  for (auto n : {2}) {
+    for (auto ic : {6}) {
+      for (auto oc : {6}) {
+        for (auto ih : {9}) {
+          for (auto iw : {9}) {
+            for (auto flag_bias : {false, true}) {
+              for (auto flag_relu : {false, true}) {
+                for (auto depthwise : {false, true}) {
+                  for (auto dilation : {1}) {
+                    for (auto stride : {1, 2}) {
+                      for (auto padding : {0, 1, 2}) {
+                        for (auto ks : {1, 3, 5}) {
+#else
   for (auto n : {1, 2}) {
     for (auto ic : {6, 32 /*, 128*/}) {
       for (auto oc : {6, 32 /*, 128*/}) {
@@ -137,19 +150,6 @@ TEST(conv_arm, compute) {
           for (auto stride : {1, 2}) {
             for (auto padding : {0, 1, 2}) {
               for (auto ks : {1, 3, 5}) {
-#else
-  for (auto n : {1}) {
-    for (auto ic : {6}) {
-      for (auto oc : {6}) {
-        for (auto ih : {9}) {
-          for (auto iw : {9}) {
-            for (auto flag_bias : {false, true}) {
-              for (auto flag_relu : {false, true}) {
-                for (auto depthwise : {false, true}) {
-                  for (auto dilation : {1}) {
-                    for (auto stride : {1}) {
-                      for (auto padding : {0, 1}) {
-                        for (auto ks : {1, 3, 5}) {
 #endif
                           int group = 1;
                           if (depthwise) {  // depthwise convolution ?
diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.cc b/paddle/fluid/lite/kernels/arm/fc_compute.cc
index b26551e0533a5ae68c930cc1b9512ba0ca13253a..efd98008e7324eb1f884d1b1cad20b3ed1b0419e 100644
--- a/paddle/fluid/lite/kernels/arm/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc
@@ -22,6 +22,10 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
+void FcCompute::PrepareForRun() {
+  // TODO(TJ): transpose weight
+}
+
 void FcCompute::Run() {
   auto& param = this->Param<operators::FcParam>();
   auto x_dims = param.input->dims();
@@ -48,22 +52,16 @@ void FcCompute::Run() {
                                 &ctx);
     lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n,
                                    x_w, false, false, false, &ctx);
-
     if (param.bias) {
       CHECK_EQ(param.bias->numel(), n);
       lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
     }
   } else {
-    // use sgemmv
-    // sgemv((const float*)weights, (const float*)din, (float*)dout,
-    //  false, n, x_w, _param->_flag_bias, (float*)bias, false);
+    lite::arm::math::sgemv(w_data, i_data, o_data, false, n, x_w,
+                           b_data != nullptr, b_data, false);
   }
 }
 
-TargetType FcCompute::target() const { return TARGET(kARM); }
-
-PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
-
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.h b/paddle/fluid/lite/kernels/arm/fc_compute.h
index 414517843354f638ed37f54ef596dc6db53193ce..459d23194d8c50f593ebc92da2d5342fb449d110 100644
--- a/paddle/fluid/lite/kernels/arm/fc_compute.h
+++ b/paddle/fluid/lite/kernels/arm/fc_compute.h
@@ -25,10 +25,9 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  public:
   using param_t = operators::FcParam;
 
-  void Run() override;
+  void PrepareForRun() override;
 
-  TargetType target() const override;
-  PrecisionType precision() const override;
+  void Run() override;
 
   virtual ~FcCompute() = default;
 };
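`FcCompute` above, and `MulCompute` below, now fall back to `sgemv` when the output collapses to a single column. A scalar sketch of what that call computes, assuming the argument order visible at the call sites (matrix, vector, output, transpose flag, rows, cols, has-bias, bias, relu flag); the transpose and relu flags are elided here and the name `sgemv_ref` is illustrative:

```cpp
#include <cstdio>

// y = A * x (+ bias), with A stored row-major as M rows of N columns.
void sgemv_ref(const float* A, const float* x, float* y, int M, int N,
               bool with_bias, const float* bias) {
  for (int m = 0; m < M; m++) {
    float acc = with_bias ? bias[m] : 0.f;
    for (int n = 0; n < N; n++) {
      acc += A[m * N + n] * x[n];
    }
    y[m] = acc;
  }
}

int main() {
  float A[6] = {1, 2, 3, 4, 5, 6};  // 2x3
  float x[3] = {1, 0, -1};
  float b[2] = {0.5f, -0.5f};
  float y[2];
  sgemv_ref(A, x, y, 2, 3, true, b);
  printf("%g %g\n", y[0], y[1]);  // -1.5 -2.5
}
```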
diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.cc b/paddle/fluid/lite/kernels/arm/mul_compute.cc
index ff12b236031896cfd8503903327ab1141b5171ae..269e4842252c2a88f33c8faf6666d139e36e49f3 100644
--- a/paddle/fluid/lite/kernels/arm/mul_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/mul_compute.cc
@@ -12,57 +12,57 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <Eigen/Core>
-#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
+#include "paddle/fluid/lite/arm/math/funcs.h"
 #include "paddle/fluid/lite/core/op_registry.h"
-#include "paddle/fluid/lite/core/types.h"
+#include "paddle/fluid/lite/core/type_system.h"
 
 namespace paddle {
 namespace lite {
 namespace kernels {
 namespace arm {
 
-template <typename T>
-void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
-                       int y_w, T* out) {
-  using matrix_t =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+void MulCompute::PrepareForRun() {
+  // TODO(TJ): transpose x or y if necessary
+}
 
-  Eigen::Map<const matrix_t> X(x, x_h, x_w);
-  Eigen::Map<const matrix_t> Y(y, y_h, y_w);
-  Eigen::Map<matrix_t> Out(out, x_h, y_w);
+void MulCompute::Run() {
+  auto& param = Param<operators::MulParam>();
 
-  Out = X * Y;
-}
+  const auto* x_data = param.x->data<float>();
+  const auto* y_data = param.y->data<float>();
+  auto* o_data = param.output->mutable_data<float>();
 
-class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
- public:
-  using param_t = operators::MulParam;
+  int m = static_cast<int>(
+      param.x->dims().Slice(0, param.x_num_col_dims).production());
+  int x_w =
+      static_cast<int>(param.x->dims()
+                           .Slice(param.x_num_col_dims, param.x->dims().size())
+                           .production());
+  int y_h = static_cast<int>(
+      param.y->dims().Slice(0, param.y_num_col_dims).production());
+  int n =
+      static_cast<int>(param.y->dims()
+                           .Slice(param.y_num_col_dims, param.y->dims().size())
+                           .production());
 
-  void Run() override {
-    auto& param = Param<operators::MulParam>();
-    core::dim2 x_shape(
-        {static_cast<int>(
-             param.x->dims().Slice(0, param.x_num_col_dims).production()),
-         static_cast<int>(
-             param.x->dims()
-                 .Slice(param.x_num_col_dims, param.x->dims().size())
-                 .production())});
-    core::dim2 y_shape(
-        {static_cast<int>(
-             param.y->dims().Slice(0, param.y_num_col_dims).production()),
-         static_cast<int>(
-             param.y->dims()
-                 .Slice(param.y_num_col_dims, param.y->dims().size())
-                 .production())});
+  CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
+  auto k = x_w;
+  if (n == 1) {
+    lite::arm::math::sgemv(x_data, y_data, o_data, false, m, k, false, nullptr,
+                           false);
 
-    mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y,  //
-                      param.y->data<float>(), y_shape.x, y_shape.y,  //
-                      param.output->mutable_data<float>());
-  }
+  } else {
+    constexpr bool is_tranposed_y = false;
+    auto& ctx = this->ctx_->template As<ARMContext>();
 
-  virtual ~MulCompute() = default;
-};
+    float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
+                      ctx.l2_cache_size() / sizeof(float);
+    lite::arm::math::prepackA(packed_x, x_data, k, 0, m, 0, k, false, &ctx);
+    lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m, n, k,
+                                   false, false, is_tranposed_y, &ctx);
+  }
+}
 
 }  // namespace arm
 }  // namespace kernels
diff --git a/paddle/fluid/lite/kernels/arm/mul_compute.h b/paddle/fluid/lite/kernels/arm/mul_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..c18995e5a5c3cceb749465382b284c0a52c188a4
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/mul_compute.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::MulParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+  virtual ~MulCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
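`MulCompute::Run` derives the GEMM shape by collapsing dimensions around `x_num_col_dims` / `y_num_col_dims`. A small sketch of that flattening with plain vectors (`product` stands in for `DDim::Slice(...).production()`; the shapes match the `num_col_dims` test below):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Product of dims[from..to), the scalar analogue of Slice(...).production().
int product(const std::vector<int>& dims, int from, int to) {
  int p = 1;
  for (int i = from; i < to; i++) p *= dims[i];
  return p;
}

int main() {
  std::vector<int> x_dims = {2, 3, 4}, y_dims = {3, 4, 5};
  int x_num_col_dims = 1, y_num_col_dims = 2;
  int x_rank = static_cast<int>(x_dims.size());
  int y_rank = static_cast<int>(y_dims.size());
  // Dims before *_num_col_dims collapse into the row count, the rest into
  // the column count, turning the tensors into an m-by-k and k-by-n matmul.
  int m = product(x_dims, 0, x_num_col_dims);       // 2
  int k = product(x_dims, x_num_col_dims, x_rank);  // 12
  int y_h = product(y_dims, 0, y_num_col_dims);     // 12
  int n = product(y_dims, y_num_col_dims, y_rank);  // 5
  assert(k == y_h);  // mirrors CHECK_EQ(x_w, y_h) in Run()
  printf("m=%d k=%d n=%d\n", m, k, n);  // m=2 k=12 n=5
}
```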
diff --git a/paddle/fluid/lite/kernels/arm/mul_compute_test.cc b/paddle/fluid/lite/kernels/arm/mul_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e3d17ec93ae9d73028343b3d4dd1e77a0fe86f0
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/mul_compute_test.cc
@@ -0,0 +1,152 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/lite/arm/math/funcs.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+template <typename T>
+void FillData(T* a, const int n, const T lower = static_cast<T>(-2.f),
+              const T upper = static_cast<T>(2.f)) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution<double> uniform_dist(0, 1);
+  for (int i = 0; i < n; ++i) {
+    a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
+  }
+}
+
+TEST(mul_arm, retrive_op) {
+  auto mul =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
+  ASSERT_FALSE(mul.empty());
+  ASSERT_TRUE(mul.front());
+}
+
+TEST(mul_arm, init) {
+  MulCompute mul;
+  ASSERT_EQ(mul.precision(), PRECISION(kFloat));
+  ASSERT_EQ(mul.target(), TARGET(kARM));
+}
+
+TEST(mul_arm, compare_test) {
+  using T = float;
+
+  for (int m : {1, 2, 3, 4}) {
+    for (int n : {1, 2, 3, 4}) {
+      for (int k : {1, 2, 3, 4}) {
+        VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k;
+        lite::Tensor x, y, out, ref;
+        x.Resize({m, k});
+        y.Resize({k, n});
+        out.Resize({m, n});
+        ref.Resize({m, n});
+
+        auto* x_data = x.mutable_data<T>();
+        auto* y_data = y.mutable_data<T>();
+        auto* out_data = out.mutable_data<T>();
+        auto* ref_data = ref.mutable_data<T>();
+
+        FillData(x_data, x.dims().production());
+        FillData(y_data, y.dims().production());
+        FillData(out_data, out.dims().production(), 0, 0);
+        FillData(ref_data, ref.dims().production(), 0, 0);
+
+        MulCompute mul;
+        operators::MulParam param;
+
+        param.x = &x;
+        param.y = &y;
+        param.output = &out;
+
+        DeviceInfo::Init();
+        std::unique_ptr<KernelContext> ctx(new KernelContext);
+        ctx->As<ARMContext>();
+        mul.SetParam(param);
+        mul.SetContext(std::move(ctx));
+        mul.PrepareForRun();
+
+        mul.Run();
+
+        lite::arm::math::mul_compute_eigen(x_data, m, k, y_data, k, n,
+                                           ref_data);
+        for (int i = 0; i < out.dims().production(); i++) {
+          EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
+        }
+      }
+    }
+  }
+}
+
+TEST(mul_arm, num_col_dims) {
+  using T = float;
+
+  lite::Tensor x, y, out, ref;
+  x.Resize({2, 3, 4});
+  y.Resize({3, 4, 5});
+  out.Resize({2, 5});
+  ref.Resize({2, 5});
+
+  auto* x_data = x.mutable_data<T>();
+  auto* y_data = y.mutable_data<T>();
+  auto* out_data = out.mutable_data<T>();
+  auto* ref_data = ref.mutable_data<T>();
+
+  FillData(x_data, x.dims().production());
+  FillData(y_data, y.dims().production());
+  FillData(out_data, out.dims().production());
+  FillData(ref_data, out.dims().production());
+
+  MulCompute mul;
+  operators::MulParam param;
+
+  param.x = &x;
+  param.y = &y;
+  param.output = &out;
+  param.x_num_col_dims = 1;
+  param.y_num_col_dims = 2;
+
+  DeviceInfo::Init();
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<ARMContext>();
+  mul.SetParam(param);
+  mul.SetContext(std::move(ctx));
+  mul.PrepareForRun();
+
+  mul.Run();
+
+  lite::arm::math::mul_compute_eigen(x_data, 2, 12, y_data, 12, 5, ref_data);
+  for (int i = 0; i < out.dims().production(); i++) {
+    EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc
index 35873a9d2cc3fa922f48cc87e8e2c4191ac8ee60..b024ccef9d526d56bcf52c1600940ff0804eaf1f 100644
--- a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc
@@ -182,7 +182,7 @@ TEST(pool_arm, compute) {
     for (auto stride : {2}) {
       for (auto pad : {0}) {
         for (auto n : {1, 3, 4, 11}) {
-          for (auto c : {1, 3, 11, 4, 1024}) {
+          for (auto c : {1, 3, 11 /* ,1024 */}) {  // speedup for ci
             for (auto h : {3, 1, 11, 4, 1}) {
               for (auto w : {1, 3, 4, 12, 1}) {
                 VLOG(3) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
diff --git a/paddle/fluid/lite/kernels/arm/scale_compute_test.cc b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc
index fee47d7eb7a6c093524bb0af617c60d069add01a..b1277792286429b666b3479c0655bb211a69db30 100644
--- a/paddle/fluid/lite/kernels/arm/scale_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc
@@ -54,6 +54,15 @@ TEST(scale_arm, compute) {
   lite::Tensor output;
   lite::Tensor output_ref;
 
+#if 1  // for ci speedup
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 3}) {
+      for (auto h : {3, 4}) {
+        for (auto w : {4, 3}) {
+          for (auto bias_after_scale : {true, false}) {
+            for (auto s : {-1.0f, 0.13f}) {
+              for (auto b : {-15.f, 0.11234f}) {
+#else
   for (auto n : {1, 3, 4, 11}) {
     for (auto c : {1, 3, 11, 4}) {
       for (auto h : {3, 1, 11, 4}) {
@@ -61,6 +70,8 @@ TEST(scale_arm, compute) {
         for (auto bias_after_scale : {true, false}) {
           for (auto s : {-100.25f, -1.0f, 0.13f, 3840.975f}) {
             for (auto b : {-3075.495f, -15.f, 0.11234f, 128.15f}) {
+#endif
+
               x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
               output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
               output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
diff --git a/paddle/fluid/lite/kernels/arm/split_compute.cc b/paddle/fluid/lite/kernels/arm/split_compute.cc
index 9da69894592e146c9191eb9da38d8d481cf287a7..3c2416bd6907199e6e83baf65c428b675462f271 100644
--- a/paddle/fluid/lite/kernels/arm/split_compute.cc
+++ b/paddle/fluid/lite/kernels/arm/split_compute.cc
@@ -24,7 +24,7 @@ namespace arm {
 void SplitCompute::Run() {
   auto& param = Param<operators::SplitParam>();
   const float* din = param.x->data<float>();
-  auto* dout = param.output;
+  auto& dout = param.output;
   auto in_dim = param.x->dims();
   std::vector<int64_t> in_strides(in_dim.size());
   in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
diff --git a/paddle/fluid/lite/kernels/arm/split_compute_test.cc b/paddle/fluid/lite/kernels/arm/split_compute_test.cc
index 808a1e2cdb7724042ffcd1324cf0dc2c5e28f2fc..39632bee8decfe875f0adb3c2717d58e593c400b 100644
--- a/paddle/fluid/lite/kernels/arm/split_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/split_compute_test.cc
@@ -24,20 +24,10 @@ namespace kernels {
 namespace arm {
 
 void splite_resize_out(const lite::Tensor* din,
-                       std::vector<lite::Tensor*>* dout, int axis, int num,
-                       const std::vector<int>& sections) {
-  for (auto out : *dout) delete out;
-  dout->clear();
+                       const std::vector<lite::Tensor*>& dout, int axis,
+                       int num, const std::vector<int>& sections) {
   auto in_dims = din->dims();
-  int outs_number;
-  if (num > 0) {
-    outs_number = num;
-  } else {
-    outs_number = sections.size();
-  }
-  for (int i = 0; i < outs_number; i++) {
-    dout->push_back(new lite::Tensor);
-  }
+  int outs_number = dout.size();
 
   std::vector<lite::DDim> outs_dims;
   outs_dims.reserve(outs_number);
@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din,
   }
 
   for (int j = 0; j < outs_dims.size(); ++j) {
-    (*dout)[j]->Resize(outs_dims[j]);
+    dout[j]->Resize(outs_dims[j]);
   }
 }
 
@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) {
   }
 
   int input_offset = 0;
-  for (auto out : *dout) {
+  for (auto out : dout) {
     auto out_dim = out->dims();
     std::vector<int64_t> out_strides(out_dim.size());
     out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
@@ -128,16 +118,31 @@ TEST(split_arm, compute) {
           for (int i = 0; i < x.dims().production(); i++) {
             x_data[i] = i;
           }
-          splite_resize_out(&x, &output, axis, num, sections);
-          splite_resize_out(&x, &output_ref, axis, num, sections);
+          for (auto out : output) delete out;
+          for (auto out : output_ref) delete out;
+          output.clear();
+          output_ref.clear();
+
+          int outs_number;
+          if (num > 0) {
+            outs_number = num;
+          } else {
+            outs_number = sections.size();
+          }
+          for (int i = 0; i < outs_number; i++) {
+            output.push_back(new lite::Tensor);
+            output_ref.push_back(new lite::Tensor);
+          }
+          splite_resize_out(&x, output, axis, num, sections);
+          splite_resize_out(&x, output_ref, axis, num, sections);
           param.x = &x;
           param.axis = axis;
           param.num = num;
-          param.sections = &sections;
-          param.output = &output;
+          param.sections = sections;
+          param.output = output;
           split.SetParam(param);
           split.Run();
-          param.output = &output_ref;
+          param.output = output_ref;
           split_compute_ref(param);
           for (int i = 0; i < output.size(); i++) {
             float* output_data = output[i]->mutable_data<float>();
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 9a90666420e94bdb585feeac689d9227fc6a2104..e996c5a29b56a9884e87b8e89b388d3ae03ee560 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
 cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
 cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
 cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} )
+cc_library(batch_norm_op_lite SRCS batch_norm_op.cc DEPS ${op_DEPS})
 cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
 cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
 cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
@@ -30,6 +31,7 @@ set(ops_lite
     scale_op_lite
     softmax_op_lite
     reshape_op_lite
+    batch_norm_op_lite
    feed_op_lite
     fetch_op_lite
     io_copy_op_lite
@@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
 lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
 lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
 lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
+lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
 lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
diff --git a/paddle/fluid/lite/operators/batch_norm_op.cc b/paddle/fluid/lite/operators/batch_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e974d0134dad93a2241c265687a190b10d5ff85d
--- /dev/null
+++ b/paddle/fluid/lite/operators/batch_norm_op.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/batch_norm_op.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool BatchNormOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.bias);
+  CHECK_OR_FALSE(param_.scale);
+  CHECK_OR_FALSE(param_.mean);
+  CHECK_OR_FALSE(param_.variance);
+  CHECK_OR_FALSE(param_.y);
+  if (!param_.is_test) {
+    CHECK_OR_FALSE(param_.mean_out);
+    CHECK_OR_FALSE(param_.variance_out);
+    CHECK_OR_FALSE(param_.saved_mean);
+    CHECK_OR_FALSE(param_.saved_variance);
+  }
+  auto x_dims = param_.x->dims();
+  auto scale_dims = param_.scale->dims();
+  auto bias_dims = param_.bias->dims();
+  auto mean_dims = param_.mean->dims();
+  auto variance_dims = param_.variance->dims();
+  CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
+      << "Input X must have 2 to 5 dimensions.";
+  CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
+  CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
+  CHECK_EQ(mean_dims.size(), 1UL) << "Input Mean must have 1 dimensions.";
+  CHECK_EQ(variance_dims.size(), 1UL)
+      << "Input Variance must have 1 dimensions.";
+  return true;
+}
+
+bool BatchNormOp::InferShape() const {
+  auto x_dims = param_.x->dims();
+  int64_t channel_size = 0;
+  switch (param_.data_layout) {
+    case DATALAYOUT(kNCHW):
+      channel_size = x_dims[1];
+      break;
+    // case DATALAYOUT(kNHWC):
+    //   channel_size = x_dims[x_dims.size() - 1];
+    //   break;
+    default:
+      LOG(FATAL) << "Unknown storage order: "
+                 << DataLayoutToStr(param_.data_layout);
+      break;
+  }
+  if (!param_.is_test) {
+    param_.mean_out->Resize({channel_size});
+    param_.variance_out->Resize({channel_size});
+    param_.saved_mean->Resize({channel_size});
+    param_.saved_variance->Resize({channel_size});
+  }
+  param_.y->Resize(x_dims);
+  return true;
+}
+
+bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  param_.x =
+      scope->FindVar(op_desc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.bias = scope->FindVar(op_desc.Input("Bias").front())
+                    ->GetMutable<lite::Tensor>();
+  param_.scale = scope->FindVar(op_desc.Input("Scale").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.mean = scope->FindVar(op_desc.Input("Mean").front())
+                    ->GetMutable<lite::Tensor>();
+  param_.variance = scope->FindVar(op_desc.Input("Variance").front())
+                        ->GetMutable<lite::Tensor>();
+  param_.y =
+      scope->FindVar(op_desc.Output("Y").front())->GetMutable<lite::Tensor>();
+  param_.is_test = op_desc.GetAttr<bool>("is_test");
+  param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
+  if (!param_.is_test) {
+    param_.mean_out = scope->FindVar(op_desc.Output("MeanOut").front())
+                          ->GetMutable<lite::Tensor>();
+    param_.variance_out = scope->FindVar(op_desc.Output("VarianceOut").front())
+                              ->GetMutable<lite::Tensor>();
+    param_.saved_mean = scope->FindVar(op_desc.Output("SavedMean").front())
+                            ->GetMutable<lite::Tensor>();
+    param_.saved_variance =
+        scope->FindVar(op_desc.Output("SavedVariance").front())
+            ->GetMutable<lite::Tensor>();
+  }
+  param_.epsilon = op_desc.GetAttr<float>("epsilon");
+  param_.momentum = op_desc.GetAttr<float>("momentum");
+  std::string data_layout = op_desc.GetAttr<std::string>("data_layout");
+  CHECK_EQ(data_layout, "NCHW") << "TODO(hong19860320): Only support NCHW.";
+  // param_.data_layout = StringToDataLayout(data_layout);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOp);
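For reference, the per-channel transform this op encodes in inference mode (`is_test` or `use_global_stats` set), with $\gamma$ = Scale, $\beta$ = Bias, $\mu$ = Mean and $\sigma^2$ = Variance, written to make the folding used by the ARM kernel explicit:

```latex
y = \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}\,\gamma_c + \beta_c
  = x \cdot \underbrace{\frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}}_{\texttt{new\_scale}[c]}
    + \underbrace{\beta_c - \frac{\gamma_c\,\mu_c}{\sqrt{\sigma_c^2 + \epsilon}}}_{\texttt{new\_bias}[c]}
```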
diff --git a/paddle/fluid/lite/operators/batch_norm_op.h b/paddle/fluid/lite/operators/batch_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..30e8747319b1575b0c63e4b2919ed1363ad10bef
--- /dev/null
+++ b/paddle/fluid/lite/operators/batch_norm_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class BatchNormOp : public OpLite {
+ public:
+  BatchNormOp() {}
+  explicit BatchNormOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "batch_norm"; }
+
+ private:
+  mutable BatchNormParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/batch_norm_op_test.cc b/paddle/fluid/lite/operators/batch_norm_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f5449e7ec19f1bfdcc3f11a7fdac65dcfbc9af17
--- /dev/null
+++ b/paddle/fluid/lite/operators/batch_norm_op_test.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/batch_norm_op.h"
+#include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+TEST(batch_norm_op_lite, test) {
+  // prepare variables
+  Scope scope;
+  auto* x = scope.Var("x")->GetMutable<Tensor>();
+  auto* scale = scope.Var("scale")->GetMutable<Tensor>();
+  auto* bias = scope.Var("bias")->GetMutable<Tensor>();
+  auto* mean = scope.Var("mean")->GetMutable<Tensor>();
+  auto* variance = scope.Var("variance")->GetMutable<Tensor>();
+  auto* y = scope.Var("y")->GetMutable<Tensor>();
+  x->Resize({2, 32, 10, 20});
+  auto x_dims = x->dims();
+  const int64_t channel_size = x_dims[1];  // NCHW
+  scale->Resize({channel_size});
+  bias->Resize({channel_size});
+  mean->Resize({channel_size});
+  variance->Resize(DDim({channel_size}));
+
+  // prepare op desc
+  cpp::OpDesc desc;
+  desc.SetType("batch_norm");
+  desc.SetInput("X", {"x"});
+  desc.SetInput("Scale", {"scale"});
+  desc.SetInput("Bias", {"bias"});
+  desc.SetInput("Mean", {"mean"});
+  desc.SetInput("Variance", {"variance"});
+  desc.SetOutput("Y", {"y"});
+  desc.SetAttr("is_test", true);
+  desc.SetAttr("use_global_stats", false);
+  desc.SetAttr("epsilon", 1e-5f);
+  desc.SetAttr("momentum", 0.9f);
+  desc.SetAttr("data_layout", std::string("NCHW"));
+
+  BatchNormOp batch_norm("batch_norm");
+
+  batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
+  batch_norm.Attach(desc, &scope);
+  batch_norm.CheckShape();
+  batch_norm.InferShape();
+
+  // check output dims
+  auto y_dims = y->dims();
+  CHECK_EQ(y_dims.size(), x_dims.size());
+  for (size_t i = 0; i < y_dims.size(); i++) {
+    CHECK_EQ(y_dims[i], x_dims[i]);
+  }
+}
+
+TEST(batch_norm_op_lite, test_enable_is_test) {
+  // prepare variables
+  Scope scope;
+  auto* x = scope.Var("x")->GetMutable<Tensor>();
+  auto* scale = scope.Var("scale")->GetMutable<Tensor>();
+  auto* bias = scope.Var("bias")->GetMutable<Tensor>();
+  auto* mean = scope.Var("mean")->GetMutable<Tensor>();
+  auto* variance = scope.Var("variance")->GetMutable<Tensor>();
+  auto* y = scope.Var("y")->GetMutable<Tensor>();
+  auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
+  auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
+  auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
+  auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
+  x->Resize({2, 32, 10, 20});
+  auto x_dims = x->dims();
+  const int64_t channel_size = x_dims[1];  // NCHW
+  scale->Resize({channel_size});
+  bias->Resize({channel_size});
+  mean->Resize({channel_size});
+  variance->Resize({channel_size});
+
+  // prepare op desc
+  cpp::OpDesc desc;
+  desc.SetType("batch_norm");
+  desc.SetInput("X", {"x"});
+  desc.SetInput("Scale", {"scale"});
+  desc.SetInput("Bias", {"bias"});
+  desc.SetInput("Mean", {"mean"});
+  desc.SetInput("Variance", {"variance"});
+  desc.SetOutput("Y", {"y"});
+  desc.SetOutput("MeanOut", {"mean_out"});
+  desc.SetOutput("VarianceOut", {"variance_out"});
+  desc.SetOutput("SavedMean", {"saved_mean"});
+  desc.SetOutput("SavedVariance", {"saved_variance"});
+  desc.SetAttr("is_test", false);
+  desc.SetAttr("use_global_stats", false);
+  desc.SetAttr("epsilon", 1e-5f);
+  desc.SetAttr("momentum", 0.9f);
+  desc.SetAttr("data_layout", std::string("NCHW"));
+
+  BatchNormOp batch_norm("batch_norm");
+
+  batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
+  batch_norm.Attach(desc, &scope);
+  batch_norm.CheckShape();
+  batch_norm.InferShape();
+
+  // check output dims
+  auto y_dims = y->dims();
+  CHECK_EQ(y_dims.size(), x_dims.size());
+  for (size_t i = 0; i < y_dims.size(); i++) {
+    CHECK_EQ(y_dims[i], x_dims[i]);
+  }
+  auto mean_out_dims = mean_out->dims();
+  auto variance_out_dims = variance_out->dims();
+  auto saved_mean_dims = saved_mean->dims();
+  auto saved_variance_dims = saved_variance->dims();
+  CHECK_EQ(mean_out_dims.size(), 1UL);
+  CHECK_EQ(variance_out_dims.size(), 1UL);
+  CHECK_EQ(saved_mean_dims.size(), 1UL);
+  CHECK_EQ(saved_variance_dims.size(), 1UL);
+  CHECK_EQ(mean_out_dims[0], channel_size);
+  CHECK_EQ(variance_out_dims[0], channel_size);
+  CHECK_EQ(saved_mean_dims[0], channel_size);
+  CHECK_EQ(saved_variance_dims[0], channel_size);
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index eee0d90dba2f3aad86a94983e0ac8fd67127b420..91a6067959854f608e31a6151a4e63e26df7eb64 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -57,6 +57,7 @@ struct FcParam {
   lite::Tensor* output{};
   lite::DDim in_mat_dims;
   int in_num_col_dims{1};
+  bool weight_transposed{false};
 };
 
 struct ReluParam {
@@ -145,6 +146,25 @@ struct ConvParam {
   std::string data_format{"Anylayout"};
 };
 
+// For BatchNorm op
+struct BatchNormParam {
+  lite::Tensor* x{};
+  lite::Tensor* bias{};
+  lite::Tensor* scale{};
+  lite::Tensor* mean{};
+  lite::Tensor* variance{};
+  lite::Tensor* y{};
+  lite::Tensor* mean_out{};
+  lite::Tensor* variance_out{};
+  lite::Tensor* saved_mean{};
+  lite::Tensor* saved_variance{};
+  bool is_test{true};
+  bool use_global_stats{false};
+  float epsilon;
+  float momentum;
+  DataLayoutType data_layout{DATALAYOUT(kNCHW)};
+};
+
 // For Pooling op
 struct PoolParam {
   lite::Tensor* x{};
@@ -177,10 +197,10 @@ struct DropoutParam {
 // For Split op
 struct SplitParam {
   lite::Tensor* x{};
-  std::vector<lite::Tensor*>* output{};
+  std::vector<lite::Tensor*> output{};
   int axis{-1};
   int num{0};
-  std::vector<int>* sections;
+  std::vector<int> sections;
 };
 
 /// ----------------------- element wise operators ----------------------
diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc
index c788e9cf9546a8c058398d71fde7aa4295fe8fbc..9b4b7662ab7ba7228ee215bf051601150e2b6bb7 100644
--- a/paddle/fluid/lite/operators/split_op.cc
+++ b/paddle/fluid/lite/operators/split_op.cc
@@ -21,7 +21,7 @@ namespace operators {
 
 bool SplitOp::CheckShape() const {
   CHECK_OR_FALSE(param_.x);
-  CHECK_OR_FALSE(param_.output);
+  CHECK_GT_OR_FALSE(param_.output.size(), 1UL);
   auto x_dims = param_.x->dims();
   auto x_rank = x_dims.size();
   CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
 
 bool SplitOp::InferShape() const {
   const auto &outs = param_.output;
-  auto in_dims = param_.x.dims();
+  auto in_dims = param_.x->dims();
   int axis = param_.axis;
   int num = param_.num;
   const auto &sections = param_.sections;
@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
   param_.x = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
-  auto outs = op_desc.Output("Out");
+  auto outs = opdesc.Output("Out");
   for (auto var : outs) {
     param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
   }
@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
+REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
diff --git a/paddle/fluid/lite/operators/split_op.h b/paddle/fluid/lite/operators/split_op.h
index 177c44171e6e67214f820f04e801be6c01df01cc..20dc4b1028c27f4efab558694285e44d46182ef8 100644
--- a/paddle/fluid/lite/operators/split_op.h
+++ b/paddle/fluid/lite/operators/split_op.h
@@ -23,7 +23,7 @@ namespace paddle {
 namespace lite {
 namespace operators {
 
-class SoftmaxOp : public OpLite {
+class SplitOp : public OpLite {
  public:
   SplitOp() {}
   explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh
index 54d938098749415ecf3d5a1ccf4ca72cd1b2b3e9..96657217cef0ca2f02a8626da7834953ea7ae358 100755
--- a/paddle/fluid/lite/tools/build.sh
+++ b/paddle/fluid/lite/tools/build.sh
@@ -59,11 +59,15 @@ function cmake_arm {
         -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2
 }
 
+function build_single {
+    #make $1 -j$(expr $(nproc) - 2)
+    make $1 -j8
+}
+
 function build {
     file=$1
     for _test in $(cat $file); do
-        #make $_test -j$(expr $(nproc) - 2)
-        make $_test -j8
+        build_single $_test
     done
 }
 
@@ -81,39 +85,6 @@ function test_lite {
     done
 }
 
-port_armv8=5554
-port_armv7=5556
-
-# Run test on android
-function test_lite_android {
-    local file=$1
-    local adb_abi=$2
-    local port=
-    if [[ ${adb_abi} == "armeabi-v7a" ]]; then
-        port=${port_armv7}
-    fi
-
-    if [[ ${adb_abi} == "arm64-v8a" ]]; then
-        port=${port_armv8}
-    fi
-    if [[ "${port}x" == "x" ]]; then
-        echo "Port can not be empty"
-        exit 1
-    fi
-
-    echo "file: ${file}"
-    # push all to adb and test
-    adb_work_dir="/data/local/tmp"
-    skip_list="test_model_parser_lite"
-    for _test in $(cat $file); do
-        [[ $skip_list =~ (^|[[:space:]])$_test($|[[:space:]]) ]] && continue || echo 'skip $_test'
-        testpath=$(find ./paddle/fluid -name ${_test})
-        adb -s emulator-${port} push ${testpath} ${adb_work_dir}
-        adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${_test}"
-        adb -s emulator-${port} shell "./${adb_work_dir}/${_test}"
-    done
-}
-
 # Build the code and run lite server tests. This is executed in the CI system.
 function build_test_server {
     mkdir -p ./build
@@ -126,8 +97,34 @@ function build_test_server {
     build $LIBS_FILE
 }
 
-# Build the code and run lite server tests. This is executed in the CI system.
+# test_arm_android
+function test_arm_android {
+    test_name=$1
+    port=$2
+    if [[ "${test_name}x" == "x" ]]; then
+        echo "test_name can not be empty"
+        exit 1
+    fi
+    if [[ "${port}x" == "x" ]]; then
+        echo "Port can not be empty"
+        exit 1
+    fi
+
+    echo "test name: ${test_name}"
+    adb_work_dir="/data/local/tmp"
+    skip_list="test_model_parser_lite" # add more with space
+    [[ $skip_list =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && continue || echo 'skip $test_name'
+    testpath=$(find ./paddle/fluid -name ${test_name})
+    adb -s emulator-${port} push ${testpath} ${adb_work_dir}
+    adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}"
+    adb -s emulator-${port} shell "./${adb_work_dir}/${test_name}"
+}
+
+# Build the code and run lite arm tests. This is executed in the CI system.
 function build_test_arm {
+    port_armv8=5554
+    port_armv7=5556
+
     adb kill-server
     adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
     # start android arm64-v8a armeabi-v7a emulators first
@@ -140,6 +137,7 @@ function build_test_arm {
 
     for os in "android" "armlinux" ; do
         for abi in "arm64-v8a" "armeabi-v7a" "armeabi-v7a-hf" ; do
+            # TODO(TJ): enable compile on v7-hf on andorid and all v7 on armlinux
            if [[ ${abi} == "armeabi-v7a-hf" ]]; then
                echo "armeabi-v7a-hf is not supported on both android and armlinux"
                continue
@@ -156,17 +154,30 @@ function build_test_arm {
             cmake_arm ${os} ${abi}
             build $TESTS_FILE
 
+            # armlinux need in another docker
+            # TODO(TJ): enable test with armlinux
             if [[ ${os} == "android" ]]; then
                 adb_abi=${abi}
                 if [[ ${adb_abi} == "armeabi-v7a-hf" ]]; then
                     adb_abi="armeabi-v7a"
                 fi
                 if [[ ${adb_abi} == "armeabi-v7a" ]]; then
-                    # skip v7 tests
+                    # skip all armv7 tests
+                    # TODO(TJ): enable test with armv7
                     continue
                 fi
-                test_lite_android $TESTS_FILE ${adb_abi}
-                # armlinux need in another docker
+                local port=
+                if [[ ${adb_abi} == "armeabi-v7a" ]]; then
+                    port=${port_armv7}
+                fi
+
+                if [[ ${adb_abi} == "arm64-v8a" ]]; then
+                    port=${port_armv8}
+                fi
+                echo "test file: ${TESTS_FILE}"
+                for _test in $(cat $TESTS_FILE); do
+                    test_arm_android $_test $port
+                done
             fi
             cd -
         done
@@ -182,12 +193,13 @@ function print_usage {
     echo "----------------------------------------"
     echo -e "cmake_x86: run cmake with X86 mode"
     echo -e "cmake_cuda: run cmake with CUDA mode"
-    echo -e "cmake_arm: run cmake with ARM mode"
+    echo -e "--arm_os=<os> --arm_abi=<abi> cmake_arm: run cmake with ARM mode"
     echo
     echo -e "build: compile the tests"
+    echo -e "--test_name=<test_name> build_single: compile single test"
    echo
     echo -e "test_server: run server tests"
-    echo -e "test_mobile: run mobile tests"
+    echo -e "--test_name=<test_name> --adb_port_number=<port> test_arm_android: run arm test"
     echo "----------------------------------------"
     echo
 }
@@ -200,11 +212,31 @@ function main {
                 TESTS_FILE="${i#*=}"
                 shift
                 ;;
+            --test_name=*)
+                TEST_NAME="${i#*=}"
+                shift
+                ;;
+            --arm_os=*)
+                ARM_OS="${i#*=}"
+                shift
+                ;;
+            --arm_abi=*)
+                ARM_ABI="${i#*=}"
+                shift
+                ;;
+            --arm_port=*)
+                ARM_PORT="${i#*=}"
+                shift
+                ;;
             build)
                 build $TESTS_FILE
                 build $LIBS_FILE
                 shift
                 ;;
+            build_single)
+                build_single $TEST_NAME
+                shift
+                ;;
             cmake_x86)
                 cmake_x86
                 shift
@@ -214,15 +246,15 @@ function main {
                 shift
                 ;;
             cmake_arm)
-                cmake_arm $2 $3
+                cmake_arm $ARM_OS $ARM_ABI
                 shift
                 ;;
             test_server)
                 test_lite $TESTS_FILE
                 shift
                 ;;
-            test_mobile)
-                test_lite $TESTS_FILE
+            test_arm_android)
+                test_arm_android $TEST_NAME $ARM_PORT
                 shift
                 ;;
            build_test_server)
@@ -250,6 +282,4 @@ function main {
     done
 }
 
-print_usage
-
 main $@
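Based on the argument parsing added to `main` above, typical invocations of the reworked script look like the following (test name and port are illustrative; flags must precede the command, since `main` processes its arguments in order):

```bash
# Configure CMake for an arm64-v8a Android target.
./paddle/fluid/lite/tools/build.sh --arm_os=android --arm_abi=arm64-v8a cmake_arm

# Compile a single test binary instead of the whole TESTS_FILE list.
./paddle/fluid/lite/tools/build.sh --test_name=test_batch_norm_compute_arm build_single

# Push one test to a running emulator and execute it there.
./paddle/fluid/lite/tools/build.sh --test_name=test_batch_norm_compute_arm \
    --arm_port=5554 test_arm_android
```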