Commit c659d037 authored by tensor-tang

Merge remote-tracking branch 'gitlab/develop' into incubate/lite

fix conflicts
...
@@ -58,6 +58,111 @@ void scale<float>(const float* din, float* dout, int num, float scale,
  }
}
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
int inner_dim, const float* scale_data,
const float* bias_data) {
int cnt = inner_dim >> 4;
int remain = inner_dim % 16;
int size = inner_dim * scale_dim;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * size;
float* dout_ptr_n = dout + n * size;
#pragma omp parallel for
for (int i = 0; i < scale_dim; i++) {
const float* din_ptr = din_ptr_n + i * inner_dim;
float* dout_ptr = dout_ptr_n + i * inner_dim;
float scale = scale_data[i];
float32x4_t vscale = vdupq_n_f32(scale);
float bias = bias_data[i];
float32x4_t vbias = vdupq_n_f32(bias);
for (int j = 0; j < cnt; j++) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
din_ptr += 16;
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
dout_ptr += 16;
}
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
din_ptr++;
}
}
}
}
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data) {
int cnt = scale_dim >> 4;
int remain = scale_dim % 16;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * scale_dim;
float* dout_ptr_n = dout + n * scale_dim;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
int idx = i << 4;
const float* din_ptr = din_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
float* dout_ptr = dout_ptr_n + idx;
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t vscale0 = vld1q_f32(scale_ptr);
float32x4_t vbias0 = vld1q_f32(bias_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
}
int idx = cnt << 4;
const float* din_ptr = din_ptr_n + idx;
float* dout_ptr = dout_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
dout_ptr++;
din_ptr++;
scale_ptr++;
bias_ptr++;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
...
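Note: both new scale overloads compute the same per-channel affine transform that the NEON blocks unroll 16 floats at a time; the scalar remainder loops above already spell it out. A minimal scalar sketch of the three-shape variant, with an illustrative function name (scale_ref is not part of the patch):

// Per-channel affine transform over an NCHW-style layout: for batch n and
// channel i, every element of the inner_dim-long slice is multiplied by
// scale_data[i] and offset by bias_data[i].
void scale_ref(const float* din, float* dout, int outer_dim, int scale_dim,
               int inner_dim, const float* scale_data, const float* bias_data) {
  for (int n = 0; n < outer_dim; n++) {
    for (int i = 0; i < scale_dim; i++) {
      const float* in = din + (n * scale_dim + i) * inner_dim;
      float* out = dout + (n * scale_dim + i) * inner_dim;
      for (int j = 0; j < inner_dim; j++) {
        out[j] = in[j] * scale_data[i] + bias_data[i];
      }
    }
  }
}

The second overload is the same computation with inner_dim == 1, i.e. one scale/bias pair per element along the last scale_dim-sized axis.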
...
@@ -22,6 +22,14 @@ namespace math {
template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
const float* scale_data, const float* bias_data);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data);
} // namespace math
} // namespace arm
} // namespace lite
...
...
@@ -52,10 +52,10 @@ void split_cpy<float>(const float* din, float* dout, int num) {
}
template <>
-void split<float>(const float* din, std::vector<lite::Tensor*>* dout,
+void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
                  const int axis, const std::vector<int>& in_strides) {
  int input_offset = 0;
-  for (auto out : *dout) {
+  for (auto out : dout) {
    auto out_dim = out->dims();
    std::vector<int> out_strides(out_dim.size());
    out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
...
...
@@ -26,7 +26,7 @@ template <typename T>
void split_cpy(const T* din, T* dout, int num);
template <typename T>
-void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
+void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
           const std::vector<int>& in_strides);
} // namespace math
...
...
@@ -54,15 +54,15 @@ void DeviceInfo::InitInternal(DeviceInfo* dev) {
              << ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]]
              << ", CPU ARCH: A" << dev->archs_[i];
  }
-  LOG(INFO) << "L1 DataCache size is: ";
+  VLOG(1) << "L1 DataCache size is: ";
  for (int i = 0; i < dev->compute_core_num_; ++i) {
-    LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB";
+    VLOG(1) << dev->L1_cache_[i] / 1024 << " KB";
  }
-  LOG(INFO) << "L2 Cache size is: ";
+  VLOG(1) << "L2 Cache size is: ";
  for (int i = 0; i < dev->compute_core_num_; ++i) {
-    LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB";
+    VLOG(1) << dev->L2_cache_[i] / 1024 << " KB";
  }
-  LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB";
+  VLOG(1) << "Total memory: " << dev->max_memory_ << "KB";
  dev->max_freq_ = max_freq[0];
  for (int j = 1; j < dev->compute_core_num_; ++j) {
...
...
@@ -6,10 +6,11 @@ message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
-cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
+cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
+cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
...
@@ -18,8 +19,10 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
+lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
+lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
set(arm_kernels
...
@@ -29,6 +32,7 @@ set(arm_kernels
    scale_compute_arm
    softmax_compute_arm
    conv_compute_arm
+    batch_norm_compute_arm
    elementwise_add_compute_arm
    pool_compute_arm
    split_compute_arm
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void BatchNormCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
int64_t channel_size = 0;
switch (param.data_layout) {
case DATALAYOUT(kNCHW):
channel_size = x_dims[1];
break;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param.data_layout);
break;
}
new_scale.Resize({channel_size});
new_bias.Resize({channel_size});
auto* scale_data = param.scale->mutable_data<float>();
auto* bias_data = param.bias->mutable_data<float>();
auto* mean_data = param.mean->mutable_data<float>();
auto* variance_data = param.variance->mutable_data<float>();
auto* new_scale_data = new_scale.mutable_data<float>();
auto* new_bias_data = new_bias.mutable_data<float>();
for (int c = 0; c < channel_size; c++) {
float inv_scale = 1.f / (std::sqrt(variance_data[c] + param.epsilon));
new_bias_data[c] =
bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
new_scale_data[c] = inv_scale * scale_data[c];
}
}
}
void BatchNormCompute::Run() {
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto x_data = param.x->mutable_data<float>();
auto y_data = param.y->mutable_data<float>();
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
auto* new_scale_data = new_scale.mutable_data<float>();
auto* new_bias_data = new_bias.mutable_data<float>();
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
switch (param.data_layout) {
case DATALAYOUT(kNCHW):
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
inner_size, new_scale_data, new_bias_data);
break;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
// new_scale_data, new_bias_data);
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param.data_layout);
break;
}
} else {
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::BatchNormCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Mean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Variance", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
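Note: PrepareForRun above folds the batch-norm statistics into the per-channel scale/bias pair that lite::arm::math::scale consumes. A small sketch of the algebra, with illustrative helper names (fold_scale/fold_bias are not part of the patch):

#include <cmath>

// Inference-mode batch norm: y = (x - mean) / sqrt(var + eps) * scale + bias,
// rewritten as y = x * new_scale + new_bias, where:
float fold_scale(float scale, float var, float eps) {
  return scale / std::sqrt(var + eps);               // new_scale
}
float fold_bias(float bias, float mean, float scale, float var, float eps) {
  return bias - fold_scale(scale, var, eps) * mean;  // new_bias
}

This is the same value the reference loop in the test below computes element by element.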
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class BatchNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::BatchNormParam;
void PrepareForRun() override;
void Run() override;
virtual ~BatchNormCompute() = default;
private:
Tensor new_scale;
Tensor new_bias;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void batch_norm_compute_ref(const operators::BatchNormParam& param) {
DDim x_dims = param.x->dims();
auto x_data = param.x->mutable_data<dtype>();
auto scale_data = param.scale->mutable_data<dtype>();
auto bias_data = param.bias->mutable_data<dtype>();
auto mean_data = param.mean->mutable_data<dtype>();
auto variance_data = param.variance->mutable_data<dtype>();
auto y_data = param.y->mutable_data<dtype>();
float epsilon = param.epsilon;
float momentum = param.momentum;
DataLayoutType data_layout = param.data_layout;
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
switch (data_layout) {
case DATALAYOUT(kNCHW):
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
break;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// inner_size = 1;
// break;
default:
LOG(FATAL) << "Unknown storage order: " << DataLayoutToStr(data_layout);
break;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
dtype norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
} else {
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
TEST(batch_norm_arm, retrive_op) {
auto batch_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"batch_norm");
ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front());
}
TEST(batch_norm_arm, init) {
BatchNormCompute batch_norm;
ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat));
ASSERT_EQ(batch_norm.target(), TARGET(kARM));
}
TEST(batch_norm_arm, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto c : {6, 32 /*, 128*/}) {
for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) {
for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) {
for (auto is_test : {/*false, */ true}) {
for (auto use_global_stats : {false, true}) {
for (auto epsilon : {1e-4f, 1e-5f}) {
for (auto momentum : {0.9f, 0.99f}) {
for (auto data_layout :
{DATALAYOUT(kNCHW) /*, DATALAYOUT(kNHWC)*/}) {
Tensor x;
Tensor scale;
Tensor bias;
Tensor mean;
Tensor variance;
Tensor y;
Tensor mean_out;
Tensor variance_out;
Tensor saved_mean;
Tensor saved_variance;
Tensor y_ref;
Tensor mean_out_ref;
Tensor variance_out_ref;
Tensor saved_mean_ref;
Tensor saved_variance_ref;
// set the dims of input, output, ref output tensors
std::vector<int64_t> in_out_shape;
switch (data_layout) {
case DATALAYOUT(kNCHW):
in_out_shape = {n, c, h, w};
break;
// case DATALAYOUT(kNHWC):
// in_out_shape = {n, h, w, c};
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(data_layout);
break;
}
x.Resize(in_out_shape);
scale.Resize({c});
bias.Resize({c});
mean.Resize({c});
variance.Resize({c});
y.Resize(in_out_shape);
mean_out.Resize({c});
variance_out.Resize({c});
saved_mean.Resize({c});
saved_variance.Resize({c});
y_ref.Resize(in_out_shape);
mean_out_ref.Resize({c});
variance_out_ref.Resize({c});
saved_mean_ref.Resize({c});
saved_variance_ref.Resize({c});
// initialize the data of input tensors
auto* x_data = x.mutable_data<float>();
auto* scale_data = scale.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* mean_data = mean.mutable_data<float>();
auto* variance_data = variance.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i % 64);
}
for (int i = 0; i < scale.dims().production(); i++) {
scale_data[i] = static_cast<float>(i) * 0.01f + 0.03f;
}
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i) * 0.065f + 0.1f;
}
for (int i = 0; i < mean.dims().production(); i++) {
mean_data[i] = static_cast<float>(i) * 0.0565f;
}
for (int i = 0; i < variance.dims().production(); i++) {
variance_data[i] = static_cast<float>(i) * 2.08f + 1.5f;
}
// prepare kernel params and run
BatchNormCompute batch_norm;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
batch_norm.SetContext(std::move(ctx));
operators::BatchNormParam param;
param.x = &x;
param.scale = &scale;
param.bias = &bias;
param.mean = &mean;
param.variance = &variance;
param.is_test = is_test;
param.use_global_stats = use_global_stats;
param.epsilon = epsilon;
param.momentum = momentum;
param.data_layout = data_layout;
param.y = &y;
param.mean_out = &mean_out;
param.variance_out = &variance_out;
param.saved_mean = &saved_mean;
param.saved_variance = &saved_variance;
batch_norm.SetParam(param);
batch_norm.Launch();
// invoking ref implementation and compare results
param.y = &y_ref;
param.mean_out = &mean_out_ref;
param.variance_out = &variance_out_ref;
param.saved_mean = &saved_mean_ref;
param.saved_variance = &saved_variance_ref;
batch_norm_compute_ref<float>(param);
auto* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.dims().production(); i++) {
EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
...
@@ -124,7 +124,20 @@ TEST(conv_arm, init) {
TEST(conv_arm, compute) {
  DeviceInfo::Init();
-#if 0
+#if 1
+  for (auto n : {2}) {
+    for (auto ic : {6}) {
+      for (auto oc : {6}) {
+        for (auto ih : {9}) {
+          for (auto iw : {9}) {
+            for (auto flag_bias : {false, true}) {
+              for (auto flag_relu : {false, true}) {
+                for (auto depthwise : {false, true}) {
+                  for (auto dilation : {1}) {
+                    for (auto stride : {1, 2}) {
+                      for (auto padding : {0, 1, 2}) {
+                        for (auto ks : {1, 3, 5}) {
+#else
  for (auto n : {1, 2}) {
    for (auto ic : {6, 32 /*, 128*/}) {
      for (auto oc : {6, 32 /*, 128*/}) {
...
@@ -137,19 +150,6 @@ TEST(conv_arm, compute) {
          for (auto stride : {1, 2}) {
            for (auto padding : {0, 1, 2}) {
              for (auto ks : {1, 3, 5}) {
-#else
-  for (auto n : {1}) {
-    for (auto ic : {6}) {
-      for (auto oc : {6}) {
-        for (auto ih : {9}) {
-          for (auto iw : {9}) {
-            for (auto flag_bias : {false, true}) {
-              for (auto flag_relu : {false, true}) {
-                for (auto depthwise : {false, true}) {
-                  for (auto dilation : {1}) {
-                    for (auto stride : {1}) {
-                      for (auto padding : {0, 1}) {
-                        for (auto ks : {1, 3, 5}) {
#endif
  int group = 1;
  if (depthwise) {  // depthwise convolution ?
...
...
@@ -22,6 +22,10 @@ namespace lite {
namespace kernels {
namespace arm {
+void FcCompute::PrepareForRun() {
+  // TODO(TJ): transpose weight
+}
void FcCompute::Run() {
  auto& param = this->Param<operators::FcParam>();
  auto x_dims = param.input->dims();
...
@@ -48,22 +52,16 @@ void FcCompute::Run() {
                                   &ctx);
    lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n,
                                   x_w, false, false, false, &ctx);
    if (param.bias) {
      CHECK_EQ(param.bias->numel(), n);
      lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n);
    }
  } else {
-    // use sgemmv
-    // sgemv((const float*)weights, (const float*)din, (float*)dout,
-    // false, n, x_w, _param->_flag_bias, (float*)bias, false);
+    lite::arm::math::sgemv(w_data, i_data, o_data, false, n, x_w,
+                           b_data != nullptr, b_data, false);
  }
}
-TargetType FcCompute::target() const { return TARGET(kARM); }
-PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
...
...
@@ -25,10 +25,9 @@ class FcCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
 public:
  using param_t = operators::FcParam;
-  void Run() override;
-  TargetType target() const override;
-  PrecisionType precision() const override;
+  void PrepareForRun() override;
+  void Run() override;
  virtual ~FcCompute() = default;
};
...
...
@@ -12,57 +12,57 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <Eigen/Core>
-#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
+#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
-#include "paddle/fluid/lite/core/types.h"
+#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
-template <typename T>
-void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
-                       int y_w, T* out) {
-  using matrix_t =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
-  Eigen::Map<const matrix_t> X(x, x_h, x_w);
-  Eigen::Map<const matrix_t> Y(y, y_h, y_w);
-  Eigen::Map<matrix_t> Out(out, x_h, y_w);
-  Out = X * Y;
-}
-class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
- public:
-  using param_t = operators::MulParam;
-  void Run() override {
-    auto& param = Param<operators::MulParam>();
-    core::dim2 x_shape(
-        {static_cast<int>(
-             param.x->dims().Slice(0, param.x_num_col_dims).production()),
-         static_cast<int>(
-             param.x->dims()
-                 .Slice(param.x_num_col_dims, param.x->dims().size())
-                 .production())});
-    core::dim2 y_shape(
-        {static_cast<int>(
-             param.y->dims().Slice(0, param.y_num_col_dims).production()),
-         static_cast<int>(
-             param.y->dims()
-                 .Slice(param.y_num_col_dims, param.y->dims().size())
-                 .production())});
-    mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y,  //
-                      param.y->data<float>(), y_shape.x, y_shape.y,  //
-                      param.output->mutable_data<float>());
-  }
-  virtual ~MulCompute() = default;
-};
+void MulCompute::PrepareForRun() {
+  // TODO(TJ): transpose x or y if necessary
+}
+void MulCompute::Run() {
+  auto& param = Param<param_t>();
+  const auto* x_data = param.x->data<float>();
+  const auto* y_data = param.y->data<float>();
+  auto* o_data = param.output->mutable_data<float>();
+  int m = static_cast<int>(
+      param.x->dims().Slice(0, param.x_num_col_dims).production());
+  int x_w =
+      static_cast<int>(param.x->dims()
+                           .Slice(param.x_num_col_dims, param.x->dims().size())
+                           .production());
+  int y_h = static_cast<int>(
+      param.y->dims().Slice(0, param.y_num_col_dims).production());
+  int n =
+      static_cast<int>(param.y->dims()
+                           .Slice(param.y_num_col_dims, param.y->dims().size())
+                           .production());
+  CHECK_EQ(x_w, y_h) << "x_w must be equal with y_h";
+  auto k = x_w;
+  if (n == 1) {
+    lite::arm::math::sgemv(x_data, y_data, o_data, false, m, k, false, nullptr,
+                           false);
+  } else {
+    constexpr bool is_tranposed_y = false;
+    auto& ctx = this->ctx_->template As<ARMContext>();
+    float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
+                      ctx.l2_cache_size() / sizeof(float);
+    lite::arm::math::prepackA(packed_x, x_data, k, 0, m, 0, k, false, &ctx);
+    lite::arm::math::sgemm_prepack(packed_x, y_data, nullptr, o_data, m, n, k,
+                                   false, false, is_tranposed_y, &ctx);
+  }
+}
} // namespace arm
} // namespace kernels
...
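Note: x_num_col_dims / y_num_col_dims flatten the inputs into 2-D matrices before the sgemv/sgemm call; for example x of shape {2, 3, 4} with x_num_col_dims = 1 gives m = 2 and k = 3 * 4 = 12, the case exercised by the num_col_dims test further down. A small sketch of that flattening, with an illustrative helper name (flatten_to_2d is not part of the patch):

#include <cstdint>
#include <utility>
#include <vector>

// Collapse dims[0, split) into the row count and dims[split, end) into the
// column count, mirroring how MulCompute::Run() derives m/k and y_h/n.
std::pair<int, int> flatten_to_2d(const std::vector<int64_t>& dims, int split) {
  int rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < split ? rows : cols) *= static_cast<int>(dims[i]);
  }
  return {rows, cols};
}
// flatten_to_2d({2, 3, 4}, 1) == {2, 12}; flatten_to_2d({3, 4, 5}, 2) == {12, 5}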
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void PrepareForRun() override;
void Run() override;
virtual ~MulCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T>
void FillData(T* a, const int n, const T lower = static_cast<T>(-2.f),
const T upper = static_cast<T>(2.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
TEST(mul_arm, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
TEST(mul_arm, init) {
MulCompute mul;
ASSERT_EQ(mul.precision(), PRECISION(kFloat));
ASSERT_EQ(mul.target(), TARGET(kARM));
}
TEST(mul_arm, compare_test) {
using T = float;
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
for (int k : {1, 2, 3, 4}) {
VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k;
lite::Tensor x, y, out, ref;
x.Resize({m, k});
y.Resize({k, n});
out.Resize({m, n});
ref.Resize({m, n});
auto* x_data = x.mutable_data<T>();
auto* y_data = y.mutable_data<T>();
auto* out_data = out.mutable_data<T>();
auto* ref_data = ref.mutable_data<T>();
FillData<T>(x_data, x.dims().production());
FillData<T>(y_data, y.dims().production());
FillData<T>(out_data, out.dims().production(), 0, 0);
FillData<T>(ref_data, ref.dims().production(), 0, 0);
MulCompute mul;
operators::MulParam param;
param.x = &x;
param.y = &y;
param.output = &out;
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
mul.SetParam(param);
mul.SetContext(std::move(ctx));
mul.PrepareForRun();
mul.Run();
lite::arm::math::mul_compute_eigen(x_data, m, k, y_data, k, n,
ref_data);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
}
}
}
}
}
TEST(mul_arm, num_col_dims) {
using T = float;
lite::Tensor x, y, out, ref;
x.Resize({2, 3, 4});
y.Resize({3, 4, 5});
out.Resize({2, 5});
ref.Resize({2, 5});
auto* x_data = x.mutable_data<T>();
auto* y_data = y.mutable_data<T>();
auto* out_data = out.mutable_data<T>();
auto* ref_data = ref.mutable_data<T>();
FillData<T>(x_data, x.dims().production());
FillData<T>(y_data, y.dims().production());
FillData<T>(out_data, out.dims().production());
FillData<T>(ref_data, out.dims().production());
MulCompute mul;
operators::MulParam param;
param.x = &x;
param.y = &y;
param.output = &out;
param.x_num_col_dims = 1;
param.y_num_col_dims = 2;
DeviceInfo::Init();
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
mul.SetParam(param);
mul.SetContext(std::move(ctx));
mul.PrepareForRun();
mul.Run();
lite::arm::math::mul_compute_eigen(x_data, 2, 12, y_data, 12, 5, ref_data);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_data[i], 1e-3);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
...
@@ -182,7 +182,7 @@ TEST(pool_arm, compute) {
  for (auto stride : {2}) {
    for (auto pad : {0}) {
      for (auto n : {1, 3, 4, 11}) {
-        for (auto c : {1, 3, 11, 4, 1024}) {
+        for (auto c : {1, 3, 11 /* ,1024 */}) {  // speedup for ci
          for (auto h : {3, 1, 11, 4, 1}) {
            for (auto w : {1, 3, 4, 12, 1}) {
              VLOG(3) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
...
...
@@ -54,6 +54,15 @@ TEST(scale_arm, compute) {
  lite::Tensor output;
  lite::Tensor output_ref;
+#if 1  // for ci speedup
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 3}) {
+      for (auto h : {3, 4}) {
+        for (auto w : {4, 3}) {
+          for (auto bias_after_scale : {true, false}) {
+            for (auto s : {-1.0f, 0.13f}) {
+              for (auto b : {-15.f, 0.11234f}) {
+#else
  for (auto n : {1, 3, 4, 11}) {
    for (auto c : {1, 3, 11, 4}) {
      for (auto h : {3, 1, 11, 4}) {
...
@@ -61,6 +70,8 @@ TEST(scale_arm, compute) {
          for (auto bias_after_scale : {true, false}) {
            for (auto s : {-100.25f, -1.0f, 0.13f, 3840.975f}) {
              for (auto b : {-3075.495f, -15.f, 0.11234f, 128.15f}) {
+#endif
                x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
                output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
                output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
...
...
@@ -24,7 +24,7 @@ namespace arm {
void SplitCompute::Run() {
  auto& param = Param<operators::SplitParam>();
  const float* din = param.x->data<float>();
-  auto* dout = param.output;
+  auto& dout = param.output;
  auto in_dim = param.x->dims();
  std::vector<int> in_strides(in_dim.size());
  in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
...
...
@@ -24,20 +24,10 @@ namespace kernels {
namespace arm {
void splite_resize_out(const lite::Tensor* din,
-                       std::vector<lite::Tensor*>* dout, int axis, int num,
-                       const std::vector<int>& sections) {
-  for (auto out : *dout) delete out;
-  dout->clear();
+                       const std::vector<lite::Tensor*>& dout, int axis,
+                       int num, const std::vector<int>& sections) {
  auto in_dims = din->dims();
-  int outs_number;
-  if (num > 0) {
-    outs_number = num;
-  } else {
-    outs_number = sections.size();
-  }
-  for (int i = 0; i < outs_number; i++) {
-    dout->push_back(new lite::Tensor);
-  }
+  int outs_number = dout.size();
  std::vector<lite::DDimLite> outs_dims;
  outs_dims.reserve(outs_number);
...
@@ -58,7 +48,7 @@ void splite_resize_out(const lite::Tensor* din,
  }
  for (int j = 0; j < outs_dims.size(); ++j) {
-    (*dout)[j]->Resize(outs_dims[j]);
+    dout[j]->Resize(outs_dims[j]);
  }
}
...
@@ -75,7 +65,7 @@ void split_compute_ref(const operators::SplitParam& param) {
  }
  int input_offset = 0;
-  for (auto out : *dout) {
+  for (auto out : dout) {
    auto out_dim = out->dims();
    std::vector<int> out_strides(out_dim.size());
    out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
...
@@ -128,16 +118,31 @@ TEST(split_arm, compute) {
    for (int i = 0; i < x.dims().production(); i++) {
      x_data[i] = i;
    }
-    splite_resize_out(&x, &output, axis, num, sections);
-    splite_resize_out(&x, &output_ref, axis, num, sections);
+    for (auto out : output) delete out;
+    for (auto out : output_ref) delete out;
+    output.clear();
+    output_ref.clear();
+    int outs_number;
+    if (num > 0) {
+      outs_number = num;
+    } else {
+      outs_number = sections.size();
+    }
+    for (int i = 0; i < outs_number; i++) {
+      output.push_back(new lite::Tensor);
+      output_ref.push_back(new lite::Tensor);
+    }
+    splite_resize_out(&x, output, axis, num, sections);
+    splite_resize_out(&x, output_ref, axis, num, sections);
    param.x = &x;
    param.axis = axis;
    param.num = num;
-    param.sections = &sections;
-    param.output = &output;
+    param.sections = sections;
+    param.output = output;
    split.SetParam(param);
    split.Run();
-    param.output = &output_ref;
+    param.output = output_ref;
    split_compute_ref<float>(param);
    for (int i = 0; i < output.size(); i++) {
      float* output_data = output[i]->mutable_data<float>();
...
...
@@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} )
+cc_library(batch_norm_op_lite SRCS batch_norm_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
...
@@ -30,6 +31,7 @@ set(ops_lite
    scale_op_lite
    softmax_op_lite
    reshape_op_lite
+    batch_norm_op_lite
    feed_op_lite
    fetch_op_lite
    io_copy_op_lite
...
@@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
+lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BatchNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.bias);
CHECK_OR_FALSE(param_.scale);
CHECK_OR_FALSE(param_.mean);
CHECK_OR_FALSE(param_.variance);
CHECK_OR_FALSE(param_.y);
if (!param_.is_test) {
CHECK_OR_FALSE(param_.mean_out);
CHECK_OR_FALSE(param_.variance_out);
CHECK_OR_FALSE(param_.saved_mean);
CHECK_OR_FALSE(param_.saved_variance);
}
auto x_dims = param_.x->dims();
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
auto mean_dims = param_.mean->dims();
auto variance_dims = param_.variance->dims();
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
CHECK_EQ(mean_dims.size(), 1UL) << "Input Mean must have 1 dimensions.";
CHECK_EQ(variance_dims.size(), 1UL)
<< "Input Variance must have 1 dimensions.";
return true;
}
bool BatchNormOp::InferShape() const {
auto x_dims = param_.x->dims();
int64_t channel_size = 0;
switch (param_.data_layout) {
case DATALAYOUT(kNCHW):
channel_size = x_dims[1];
break;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param_.data_layout);
break;
}
if (!param_.is_test) {
param_.mean_out->Resize({channel_size});
param_.variance_out->Resize({channel_size});
param_.saved_mean->Resize({channel_size});
param_.saved_variance->Resize({channel_size});
}
param_.y->Resize(x_dims);
return true;
}
bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
param_.scale =
scope->FindVar(op_desc.Input("Scale").front())->GetMutable<Tensor>();
param_.mean =
scope->FindVar(op_desc.Input("Mean").front())->GetMutable<Tensor>();
param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.is_test = op_desc.GetAttr<bool>("is_test");
param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
if (!param_.is_test) {
param_.mean_out =
scope->FindVar(op_desc.Output("MeanOut").front())->GetMutable<Tensor>();
param_.variance_out = scope->FindVar(op_desc.Output("VarianceOut").front())
->GetMutable<Tensor>();
param_.saved_mean = scope->FindVar(op_desc.Output("SavedMean").front())
->GetMutable<Tensor>();
param_.saved_variance =
scope->FindVar(op_desc.Output("SavedVariance").front())
->GetMutable<Tensor>();
}
param_.epsilon = op_desc.GetAttr<float>("epsilon");
param_.momentum = op_desc.GetAttr<float>("momentum");
std::string data_layout = op_desc.GetAttr<std::string>("data_layout");
CHECK_EQ(data_layout, "NCHW") << "TODO(hong19860320): Only support NCHW.";
// param_.data_layout = StringToDataLayout(data_layout);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class BatchNormOp : public OpLite {
public:
BatchNormOp() {}
explicit BatchNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "batch_norm"; }
private:
mutable BatchNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(batch_norm_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* scale = scope.Var("scale")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* mean = scope.Var("mean")->GetMutable<Tensor>();
auto* variance = scope.Var("variance")->GetMutable<Tensor>();
auto* y = scope.Var("y")->GetMutable<Tensor>();
x->Resize({2, 32, 10, 20});
auto x_dims = x->dims();
const int64_t channel_size = x_dims[1]; // NCHW
scale->Resize({channel_size});
bias->Resize({channel_size});
mean->Resize({channel_size});
variance->Resize(DDim({channel_size}));
// prepare op desc
cpp::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"x"});
desc.SetInput("Scale", {"scale"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", true);
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
desc.SetAttr("data_layout", std::string("NCHW"));
BatchNormOp batch_norm("batch_norm");
batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
batch_norm.Attach(desc, &scope);
batch_norm.CheckShape();
batch_norm.InferShape();
// check output dims
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), x_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
CHECK_EQ(y_dims[i], x_dims[i]);
}
}
TEST(batch_norm_op_lite, test_enable_is_test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* scale = scope.Var("scale")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* mean = scope.Var("mean")->GetMutable<Tensor>();
auto* variance = scope.Var("variance")->GetMutable<Tensor>();
auto* y = scope.Var("y")->GetMutable<Tensor>();
auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
x->Resize({2, 32, 10, 20});
auto x_dims = x->dims();
const int64_t channel_size = x_dims[1]; // NCHW
scale->Resize({channel_size});
bias->Resize({channel_size});
mean->Resize({channel_size});
variance->Resize({channel_size});
// prepare op desc
cpp::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"x"});
desc.SetInput("Scale", {"scale"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetOutput("MeanOut", {"mean_out"});
desc.SetOutput("VarianceOut", {"variance_out"});
desc.SetOutput("SavedMean", {"saved_mean"});
desc.SetOutput("SavedVariance", {"saved_variance"});
desc.SetAttr("is_test", false);
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
desc.SetAttr("data_layout", std::string("NCHW"));
BatchNormOp batch_norm("batch_norm");
batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
batch_norm.Attach(desc, &scope);
batch_norm.CheckShape();
batch_norm.InferShape();
// check output dims
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), x_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
CHECK_EQ(y_dims[i], x_dims[i]);
}
auto mean_out_dims = mean_out->dims();
auto variance_out_dims = variance_out->dims();
auto saved_mean_dims = saved_mean->dims();
auto saved_variance_dims = saved_variance->dims();
CHECK_EQ(mean_out_dims.size(), 1UL);
CHECK_EQ(variance_out_dims.size(), 1UL);
CHECK_EQ(saved_mean_dims.size(), 1UL);
CHECK_EQ(saved_variance_dims.size(), 1UL);
CHECK_EQ(mean_out_dims[0], channel_size);
CHECK_EQ(variance_out_dims[0], channel_size);
CHECK_EQ(saved_mean_dims[0], channel_size);
CHECK_EQ(saved_variance_dims[0], channel_size);
}
} // namespace operators
} // namespace lite
} // namespace paddle
...
@@ -57,6 +57,7 @@ struct FcParam {
  lite::Tensor* output{};
  lite::DDim in_mat_dims;
  int in_num_col_dims{1};
+  bool weight_transposed{false};
};
struct ReluParam {
...
@@ -145,6 +146,25 @@ struct ConvParam {
  std::string data_format{"Anylayout"};
};
// For BatchNorm op
struct BatchNormParam {
lite::Tensor* x{};
lite::Tensor* bias{};
lite::Tensor* scale{};
lite::Tensor* mean{};
lite::Tensor* variance{};
lite::Tensor* y{};
lite::Tensor* mean_out{};
lite::Tensor* variance_out{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
bool is_test{true};
bool use_global_stats{false};
float epsilon;
float momentum;
DataLayoutType data_layout{DATALAYOUT(kNCHW)};
};
// For Pooling op
struct PoolParam {
  lite::Tensor* x{};
...
@@ -177,10 +197,10 @@ struct DropoutParam {
// For Split op
struct SplitParam {
  lite::Tensor* x{};
-  std::vector<lite::Tensor*>* output{};
+  std::vector<lite::Tensor*> output{};
  int axis{-1};
  int num{0};
-  std::vector<int>* sections;
+  std::vector<int> sections;
};
/// ----------------------- element wise operators ----------------------
...
...
@@ -21,7 +21,7 @@ namespace operators {
bool SplitOp::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
-  CHECK_OR_FALSE(param_.output);
+  CHECK_GT_OR_FALSE(param_.output.size(), 1UL);
  auto x_dims = param_.x->dims();
  auto x_rank = x_dims.size();
  CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
...
@@ -31,7 +31,7 @@ bool SplitOp::CheckShape() const {
bool SplitOp::InferShape() const {
  const auto &outs = param_.output;
-  auto in_dims = param_.x.dims();
+  auto in_dims = param_.x->dims();
  int axis = param_.axis;
  int num = param_.num;
  const auto &sections = param_.sections;
...
@@ -68,7 +68,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
  param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
  param_.x = const_cast<lite::Tensor *>(
      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
-  auto outs = op_desc.Output("Out");
+  auto outs = opdesc.Output("Out");
  for (auto var : outs) {
    param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
  }
...
@@ -79,4 +79,4 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} // namespace lite
} // namespace paddle
-REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
+REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
...
@@ -23,7 +23,7 @@ namespace paddle {
namespace lite {
namespace operators {
-class SoftmaxOp : public OpLite {
+class SplitOp : public OpLite {
 public:
  SplitOp() {}
  explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
...
...
@@ -59,11 +59,15 @@ function cmake_arm {
        -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2
}
+function build_single {
+    #make $1 -j$(expr $(nproc) - 2)
+    make $1 -j8
+}
function build {
    file=$1
    for _test in $(cat $file); do
-        #make $_test -j$(expr $(nproc) - 2)
-        make $_test -j8
+        build_single $_test
    done
}
...
@@ -81,39 +85,6 @@ function test_lite {
    done
}
port_armv8=5554
port_armv7=5556
# Run test on android
function test_lite_android {
local file=$1
local adb_abi=$2
local port=
if [[ ${adb_abi} == "armeabi-v7a" ]]; then
port=${port_armv7}
fi
if [[ ${adb_abi} == "arm64-v8a" ]]; then
port=${port_armv8}
fi
if [[ "${port}x" == "x" ]]; then
echo "Port can not be empty"
exit 1
fi
echo "file: ${file}"
# push all to adb and test
adb_work_dir="/data/local/tmp"
skip_list="test_model_parser_lite"
for _test in $(cat $file); do
[[ $skip_list =~ (^|[[:space:]])$_test($|[[:space:]]) ]] && continue || echo 'skip $_test'
testpath=$(find ./paddle/fluid -name ${_test})
adb -s emulator-${port} push ${testpath} ${adb_work_dir}
adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${_test}"
adb -s emulator-${port} shell "./${adb_work_dir}/${_test}"
done
}
# Build the code and run lite server tests. This is executed in the CI system.
function build_test_server {
    mkdir -p ./build
...
@@ -126,8 +97,34 @@ function build_test_server {
    build $LIBS_FILE
}
-# Build the code and run lite server tests. This is executed in the CI system.
+# test_arm_android <some_test_name> <adb_port_number>
function test_arm_android {
test_name=$1
port=$2
if [[ "${test_name}x" == "x" ]]; then
echo "test_name can not be empty"
exit 1
fi
if [[ "${port}x" == "x" ]]; then
echo "Port can not be empty"
exit 1
fi
echo "test name: ${test_name}"
adb_work_dir="/data/local/tmp"
skip_list="test_model_parser_lite" # add more with space
[[ $skip_list =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && continue || echo 'skip $test_name'
testpath=$(find ./paddle/fluid -name ${test_name})
adb -s emulator-${port} push ${testpath} ${adb_work_dir}
adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}"
adb -s emulator-${port} shell "./${adb_work_dir}/${test_name}"
}
# Build the code and run lite arm tests. This is executed in the CI system.
function build_test_arm {
+    port_armv8=5554
+    port_armv7=5556
    adb kill-server
    adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
    # start android arm64-v8a armeabi-v7a emulators first
...
@@ -140,6 +137,7 @@ function build_test_arm {
    for os in "android" "armlinux" ; do
        for abi in "arm64-v8a" "armeabi-v7a" "armeabi-v7a-hf" ; do
+            # TODO(TJ): enable compile on v7-hf on andorid and all v7 on armlinux
            if [[ ${abi} == "armeabi-v7a-hf" ]]; then
                echo "armeabi-v7a-hf is not supported on both android and armlinux"
                continue
...
@@ -156,17 +154,30 @@ function build_test_arm {
            cmake_arm ${os} ${abi}
            build $TESTS_FILE
+            # armlinux need in another docker
+            # TODO(TJ): enable test with armlinux
            if [[ ${os} == "android" ]]; then
                adb_abi=${abi}
                if [[ ${adb_abi} == "armeabi-v7a-hf" ]]; then
                    adb_abi="armeabi-v7a"
                fi
                if [[ ${adb_abi} == "armeabi-v7a" ]]; then
-                    # skip v7 tests
+                    # skip all armv7 tests
+                    # TODO(TJ): enable test with armv7
                    continue
                fi
-                test_lite_android $TESTS_FILE ${adb_abi}
-                # armlinux need in another docker
+                local port=
+                if [[ ${adb_abi} == "armeabi-v7a" ]]; then
port=${port_armv7}
fi
if [[ ${adb_abi} == "arm64-v8a" ]]; then
port=${port_armv8}
fi
echo "test file: ${TESTS_FILE}"
for _test in $(cat $TESTS_FILE); do
test_arm_android $_test $port
done
            fi
            cd -
        done
...
@@ -182,12 +193,13 @@ function print_usage {
    echo "----------------------------------------"
    echo -e "cmake_x86: run cmake with X86 mode"
    echo -e "cmake_cuda: run cmake with CUDA mode"
-    echo -e "cmake_arm: run cmake with ARM mode"
+    echo -e "--arm_os=<os> --arm_abi=<abi> cmake_arm: run cmake with ARM mode"
    echo
    echo -e "build: compile the tests"
+    echo -e "--test_name=<test_name> build_single: compile single test"
    echo
    echo -e "test_server: run server tests"
-    echo -e "test_mobile: run mobile tests"
+    echo -e "--test_name=<test_name> --adb_port_number=<adb_port_number> test_arm_android: run arm test"
    echo "----------------------------------------"
    echo
}
...
@@ -200,11 +212,31 @@ function main {
            TESTS_FILE="${i#*=}"
            shift
            ;;
--test_name=*)
TEST_NAME="${i#*=}"
shift
;;
--arm_os=*)
ARM_OS="${i#*=}"
shift
;;
--arm_abi=*)
ARM_ABI="${i#*=}"
shift
;;
--arm_port=*)
ARM_PORT="${i#*=}"
shift
;;
        build)
            build $TESTS_FILE
            build $LIBS_FILE
            shift
            ;;
build_single)
build_single $TEST_NAME
shift
;;
        cmake_x86)
            cmake_x86
            shift
...
@@ -214,15 +246,15 @@ function main {
            shift
            ;;
        cmake_arm)
-            cmake_arm $2 $3
+            cmake_arm $ARM_OS $ARM_ABI
            shift
            ;;
        test_server)
            test_lite $TESTS_FILE
            shift
            ;;
-        test_mobile)
-            test_lite $TESTS_FILE
+        test_arm_android)
+            test_arm_android $TEST_NAME $ARM_PORT
            shift
            ;;
        build_test_server)
...
@@ -250,6 +282,4 @@ function main {
    done
}
-print_usage
main $@