diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h
index 6e8aec99e5f595381efa98e7fb04501c13ddf4de..7eeb7f76670aa5c5a39544484ac92e611ff9066a 100644
--- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h
@@ -32,6 +32,7 @@ void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
   Tensor new_scale = *param.NewScale();
   Tensor *output = param.Output();
+  output->mutable_data<float>();

   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
index cef297daad3c83253105ccf2c44d195e01d074ae..e0299d00ae09de62c133676449f0148a49beae5e 100644
--- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
@@ -32,6 +32,7 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
   Tensor new_scale = *param.NewScale();
   Tensor *output = param.Output();
+  output->mutable_data<float>();

   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp
index ab47126329d5f5c9b8607250dff086a31466fcec..8220e20429ef3b26acb1f0f130ecd41f2954a3c2 100644
--- a/src/operators/math/depthwise_conv3x3.cpp
+++ b/src/operators/math/depthwise_conv3x3.cpp
@@ -564,7 +564,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
-  float *output_data = output->data<float>();
+  float *output_data = output->mutable_data<float>();
   const float *newscale_data = new_scale->data<float>();
   const float *newbias_data = new_bias->data<float>();
@@ -1309,7 +1309,7 @@ void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
-  float *output_data = output->data<float>();
+  float *output_data = output->mutable_data<float>();
   const float *bias_data;
   if (if_bias) {
     bias_data = bias->data<float>();
@@ -1729,7 +1729,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,

   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
-  float *output_data = output->data<float>();
+  float *output_data = output->mutable_data<float>();
   const float *newscale_data = new_scale->data<float>();
   const float *newbias_data = new_bias->data<float>();
@@ -1978,6 +1978,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
   const int output_width = static_cast<int>(output->dims()[3]);
   const int inhxw = input_height * input_width;
   const int outhxw = output_height * output_width;
+  output->mutable_data<float>();

   float32x4_t zero = vdupq_n_f32(0.0);
   for (int b = 0; b < batch_size; b++) {
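
Note: the recurring one-line `output->mutable_data<float>()` additions above all fix the same latent bug. In the paddle-mobile Tensor API, `data<T>()` only returns a pointer to memory that already exists, while `mutable_data<T>()` allocates (or reallocates) the buffer to match the tensor's current dims. A minimal sketch of the calling order these hunks enforce (the shape here is invented for illustration):

    framework::Tensor output;
    output.Resize(framework::make_ddim({1, 8, 16, 16}));
    // Allocates 1*8*16*16 floats and returns a writable pointer.
    float *out_ptr = output.mutable_data<float>();
    // Only from here on is output.data<float>() a valid pointer.
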
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 23634f33fe01dbfbc994f48a522c30c966fc7087..8b52faf184bf79211b39ce46ae21e0668d1dafc2 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -337,8 +337,8 @@ if (NOT FOUND_MATCH)
     target_link_libraries(test-genet paddle-mobile)

     # gen test
-    ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
-    target_link_libraries(test-sigmoid paddle-mobile)
+    ADD_EXECUTABLE(test-sigmoid-op operators/test_sigmoid_op.cpp test_include.h)
+    target_link_libraries(test-sigmoid-op paddle-mobile)

     # gen test
     ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
@@ -408,14 +408,14 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h)
     target_link_libraries(test-ocr paddle-mobile)

-    ADD_EXECUTABLE(test-sequence-expand operators/test_sequence_expand_op.cpp test_helper.h test_include.h)
-    target_link_libraries(test-sequence-expand paddle-mobile)
+    ADD_EXECUTABLE(test-sequence-expand-op operators/test_sequence_expand_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-sequence-expand-op paddle-mobile)

-    ADD_EXECUTABLE(test-sequence-pool operators/test_sequence_pool_op.cpp test_helper.h test_include.h)
-    target_link_libraries(test-sequence-pool paddle-mobile)
+    ADD_EXECUTABLE(test-sequence-pool-op operators/test_sequence_pool_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-sequence-pool-op paddle-mobile)

-    ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
-    target_link_libraries(test-sequence-softmax paddle-mobile)
+    ADD_EXECUTABLE(test-sequence-softmax-op operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-sequence-softmax-op paddle-mobile)

     # gen test
     ADD_EXECUTABLE(test-vgg16ssd net/test_vgg16ssd.cpp test_helper.h test_include.h)
@@ -437,4 +437,9 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-logical-xor-op paddle-mobile)

+    ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-conv-bn-relu-op paddle-mobile)
+
+    ADD_EXECUTABLE(test-dwconv-bn-relu-op operators/test_dwconv_bn_relu_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-dwconv-bn-relu-op paddle-mobile)
 endif ()
diff --git a/test/operators/test_conv_bn_relu_op.cpp b/test/operators/test_conv_bn_relu_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6a09d838e0a30486569448726c255b1a6ba7f617
--- /dev/null
+++ b/test/operators/test_conv_bn_relu_op.cpp
@@ -0,0 +1,172 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "../test_helper.h"
+#include "../test_include.h"
+#include "operators/fusion_conv_bn_relu_op.h"
+
+namespace paddle_mobile {
+
+// Benchmark the fused conv+bn+relu operator on the layer shapes listed in
+// main() below; per-run latency is appended to ./out_conv.txt.
+template <int Kernel, int Pad, int Stride>
+int TestConvBnReluOp(int in_channels, int in_height, int in_width,
+                     int out_channels, int groups, std::string opname) {
+  int kernel_h = Kernel;
+  int kernel_w = Kernel;
+  int pad_h = Pad;
+  int pad_w = Pad;
+  int stride_h = Stride;
+  int stride_w = Stride;
+  int dilation_h = 1;
+  int dilation_w = 1;
+
+  int batch_size = 1;
+  int input_c = in_channels;
+  int input_h = in_height;
+  int input_w = in_width;
+  int output_c = out_channels;
+  framework::DDim input_shape =
+      framework::make_ddim({batch_size, input_c, input_h, input_w});
+  framework::DDim filter_shape =
+      framework::make_ddim({output_c, input_c / groups, kernel_h, kernel_w});
+  framework::DDim shape = framework::make_ddim({output_c});
+
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["Input"] = std::vector<std::string>({"input"});
+  inputs["Filter"] = std::vector<std::string>({"filter"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+  inputs["Mean"] = std::vector<std::string>({"input_mean"});
+  inputs["Variance"] = std::vector<std::string>({"input_variance"});
+  inputs["Scale"] = std::vector<std::string>({"input_scale"});
+  inputs["Bias"] = std::vector<std::string>({"input_bias"});
+
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, input_shape, -20.0, 20.0);
+
+  auto filter_var = scope.get()->Var("filter");
+  auto filter = filter_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(filter, filter_shape, -20, 20);
+
+  auto input_mean_var = scope.get()->Var("input_mean");
+  auto input_mean = input_mean_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input_mean, shape, -10.0, 10.0);
+  auto vari_var = scope.get()->Var("input_variance");
+  auto vari = vari_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(vari, shape, -10.0, 10.0);
+  auto scale_var = scope.get()->Var("input_scale");
+  auto scale = scale_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(scale, shape, -10.0, 10.0);
+  auto input_bias_var = scope.get()->Var("input_bias");
+  auto input_bias = input_bias_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input_bias, shape, -10.0, 10.0);
+
+  auto output_var = scope.get()->Var("output");
+  framework::AttributeMap attrs;
+  attrs["strides"].Set<std::vector<int>>(std::vector<int>({stride_h, stride_w}));
+  attrs["paddings"].Set<std::vector<int>>(std::vector<int>({pad_h, pad_w}));
+  attrs["dilations"].Set<std::vector<int>>(
+      std::vector<int>({dilation_h, dilation_w}));
+  attrs["groups"].Set<int>(groups);
+  attrs["epsilon"].Set<float>(1e-6);
+  attrs["momentum"].Set<float>(0.f);
+
+  auto *op = new operators::FusionConvBNReluOp<CPU, float>(
+      "fusion_conv_bn_relu", inputs, outputs, attrs, scope);
+  op->InferShape();
+  op->Init();
+  for (int i = 0; i < 10; ++i) {
+    op->Run();
+  }
+  auto time1 = time();
+  for (int i = 0; i < 10; ++i) {
+    op->Run();
+  }
+  auto time2 = time();
+  std::ofstream out_file("./out_conv.txt", std::ios::app);
+  out_file << opname << " cost :" << time_diff(time1, time2) / 10.0 << "ms"
+           << std::endl;
+  out_file.close();
+
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
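
Note: both new benchmark files follow the same harness pattern: build the op from a VariableNameMap/AttributeMap, call InferShape() and Init() once, then run untimed warm-up iterations before the timed loop so lazy allocation and cold caches don't pollute the measurement. Distilled (time() and time_diff() are the existing helpers from test_helper.h; the iteration count of 10 matches the tests):

    for (int i = 0; i < 10; ++i) op->Run();  // warm-up, not timed
    auto start = time();
    for (int i = 0; i < 10; ++i) op->Run();  // timed
    auto end = time();
    double avg_ms = time_diff(start, end) / 10.0;  // what gets logged
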
+
+int main(int argc, char *argv[]) {
+  // kernel = 3, pad = 1, stride = 2
+  paddle_mobile::TestConvBnReluOp<3, 1, 2>(3, 48, 48, 16, 1, "conv_bn_relu");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(16, 24, 24, 8, 1,
+                                           "depthwise_separable");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(8, 24, 24, 24, 1,
+                                           "MBConv_3x3_conv1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(24, 24, 24, 8, 1, "MBConv_3x3_pw1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(8, 24, 24, 24, 1,
+                                           "MBConv_3x3_conv2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(24, 24, 24, 8, 1, "MBConv_3x3_pw2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(8, 24, 24, 24, 1,
+                                           "MBConv_3x3_conv3");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(24, 12, 12, 16, 1, "MBConv_3x3_pw3");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      16, 12, 12, 48, 1, "MBConv_5x5_stage1_conv1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      48, 12, 12, 16, 1, "MBConv_5x5_stage1_pw1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      16, 12, 12, 48, 1, "MBConv_5x5_stage1_conv2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      48, 12, 12, 16, 1, "MBConv_5x5_stage1_pw2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      16, 12, 12, 48, 1, "MBConv_5x5_stage1_conv3");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      48, 6, 6, 32, 1, "MBConv_5x5_stage1_pw3");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      32, 6, 6, 192, 1, "MBConv_5x5_stage2_conv1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      192, 6, 6, 32, 1, "MBConv_5x5_stage2_pw1");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      32, 6, 6, 192, 1, "MBConv_5x5_stage2_conv2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      192, 6, 6, 32, 1, "MBConv_5x5_stage2_pw2");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      32, 6, 6, 192, 1, "MBConv_5x5_stage2_conv3");
+  // kernel = 1, pad = 0, stride = 1
+  paddle_mobile::TestConvBnReluOp<1, 0, 1>(
+      192, 6, 6, 64, 1, "MBConv_5x5_stage2_pw3");
+
+  return 0;
+}
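
Note: for reference when reading the shapes above, each layer's spatial output size follows the standard convolution formula, which is why the 48x48 stride-2 3x3 layer feeds 24x24 into the 1x1 layers that follow. A sketch (hypothetical helper, not part of the test code):

    // out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    int conv_out_size(int in, int kernel, int pad, int stride, int dilation = 1) {
      return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    }
    // conv_out_size(48, 3, 1, 2) == 24; conv_out_size(24, 1, 0, 1) == 24
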
diff --git a/test/operators/test_conv_op.cpp b/test/operators/test_conv_op.cpp
index c596c1def4006853532395f151c6e9c47cf8e3e8..3a949daefeb89df1c72702f1207a0d0f0e652f93 100644
--- a/test/operators/test_conv_op.cpp
+++ b/test/operators/test_conv_op.cpp
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <iostream>
 #include "../test_helper.h"
 #include "../test_include.h"
 #include "operators/conv_op.h"
@@ -209,10 +210,10 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
     // PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
     //                       "output[%d] = %d, output_cmp[%d] = %d", i,
     //                       output_data[i], i, output_cmp_data[i]);
-    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
-      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
-                     << ", output_cmp_data[" << i
-                     << "] = " << output_cmp_data[i];
+    if (gap > 1e-2 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      std::cerr << "output_data[" << i << "] = " << output_data[i]
+                << ", output_cmp_data[" << i << "] = " << output_cmp_data[i]
+                << std::endl;
       exit(1);
     }
   }
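
Note: the relaxed check now requires a mismatch to be both absolutely large (gap > 1e-2) and relatively large (> 0.1%) before failing, so near-zero outputs no longer trip the relative test on rounding noise. Extracted for clarity (hypothetical helper; thresholds as in the diff — note that gap is not wrapped in std::abs, so only positive deviations can fail; if that is unintended, std::abs(gap) would tighten it):

    bool mismatch(float got, float want) {
      float gap = got - want;
      return gap > 1e-2 && std::abs(gap / (got + 1e-5)) > 1e-3;
    }
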
"float, kernel=3, pad=0, stride=1"; + std::cerr << "int8, kernel=3, pad=0, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 1, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; + std::cerr << "int8, kernel=3, pad=1, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 2, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=2, stride=1"; + std::cerr << "int8, kernel=3, pad=2, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 5, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=5, stride=1"; + std::cerr << "int8, kernel=3, pad=5, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 0, stride = 2 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=2"; + std::cerr << "int8, kernel=3, pad=0, stride=2" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 1, stride = 2 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=2"; + std::cerr << "int8, kernel=3, pad=1, stride=2" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 2, stride = 2 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=2, stride=2"; + std::cerr << "int8, kernel=3, pad=2, stride=2" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 3, pad = 5, stride = 2 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=5, stride=2"; + std::cerr << "int8, kernel=3, pad=5, stride=2" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); +#endif // __aarch64__ // kernel = 5, pad = 0, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; + std::cerr << "float, kernel=5, pad=0, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 1, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=1, stride=1"; + std::cerr << "float, kernel=5, pad=1, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 2, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; + std::cerr << "float, kernel=5, pad=2, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 5, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=5, stride=1"; + std::cerr << "float, kernel=5, pad=5, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); +#ifndef __aarch64__ // kernel = 5, pad = 0, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; + std::cerr << "int8, kernel=5, pad=0, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 1, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=1, stride=1"; + std::cerr << "int8, kernel=5, pad=1, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 2, stride = 1 - 
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; + std::cerr << "int8, kernel=5, pad=2, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); // kernel = 5, pad = 5, stride = 1 - LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=5, stride=1"; + std::cerr << "int8, kernel=5, pad=5, stride=1" << std::endl; paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); +#endif // __aarch64__ + + return 0; +} + +int main() { + TestAll(1, 5, 5, 1, 1); + TestAll(1, 5, 5, 10, 1); + TestAll(10, 5, 5, 10, 10); + + TestAll(5, 33, 33, 5, 1); + TestAll(5, 33, 33, 13, 1); + TestAll(13, 33, 33, 13, 13); + TestAll(5, 33, 13, 5, 1); + TestAll(5, 33, 13, 13, 1); + TestAll(13, 33, 13, 13, 13); return 0; } diff --git a/test/operators/test_dwconv_bn_relu_op.cpp b/test/operators/test_dwconv_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7fcf10d903e571ac7b0f5fb0a4b1214bf55327d1 --- /dev/null +++ b/test/operators/test_dwconv_bn_relu_op.cpp @@ -0,0 +1,145 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../test_helper.h" +#include "../test_include.h" +#include "operators/fusion_dwconv_bn_relu_op.h" + +namespace paddle_mobile { + +template +int TestDWConvAddBnReluOp(int in_channels, int in_height, int in_width, + int out_channels, int groups, std::string opname) { + int kernel_h = Kernel; + int kernel_w = Kernel; + int pad_h = Pad; + int pad_w = Pad; + int stride_h = Stride; + int stride_w = Stride; + int dilation_h = 1; + int dilation_w = 1; + + int batch_size = 1; + int input_c = in_channels; + int input_h = in_height; + int input_w = in_width; + int output_c = out_channels; + framework::DDim input_shape = + framework::make_ddim({batch_size, input_c, input_h, input_w}); + framework::DDim filter_shape = + framework::make_ddim({output_c, input_c / groups, kernel_h, kernel_w}); + framework::DDim shape = framework::make_ddim({output_c}); + + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["Input"] = std::vector({"input"}); + inputs["Filter"] = std::vector({"filter"}); + inputs["Mean"] = std::vector({"mean"}); + inputs["Variance"] = std::vector({"variance"}); + inputs["Scale"] = std::vector({"scale"}); + inputs["Bias"] = std::vector({"bias"}); + outputs["Out"] = std::vector({"output"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -20.0, 20.0); + + auto filter_var = scope.get()->Var("filter"); + auto filter = filter_var->template GetMutable(); + SetupTensor(filter, filter_shape, -20, 20); + + auto mean_var = scope.get()->Var("mean"); + auto mean = mean_var->template GetMutable(); + SetupTensor(mean, shape, -10.0, 10.0); + + auto vari_var = scope.get()->Var("variance"); + auto vari = vari_var->template GetMutable(); + SetupTensor(vari, shape, -10.0, 10.0); + + auto scale_var = 
scope.get()->Var("scale"); + auto scale = scale_var->template GetMutable(); + SetupTensor(scale, shape, -10.0, 10.0); + + auto bias_var = scope.get()->Var("bias"); + auto bias = bias_var->template GetMutable(); + SetupTensor(bias, shape, -10.0, 10.0); + + auto output_var = scope.get()->Var("output"); + framework::AttributeMap attrs; + attrs["strides"].Set>(std::vector({stride_h, stride_w})); + attrs["paddings"].Set>(std::vector({pad_h, pad_w})); + attrs["dilations"].Set>( + std::vector({dilation_h, dilation_w})); + attrs["groups"].Set(groups); + attrs["epsilon"].Set(1e-6); + attrs["momentum"].Set(0.f); + + auto *op = new operators::FusionDWConvBNReluOp( + "fusion_dwconv_bn_relu", inputs, outputs, attrs, scope); + op->InferShape(); + op->Init(); + for (int i = 0; i < 10; ++i) { + op->Run(); + } + auto time1 = time(); + for (int i = 0; i < 10; ++i) { + op->Run(); + } + auto time2 = time(); + std::ofstream out_file("./out_dwconv.txt", std::ios::app); + out_file << opname << " cost :" << time_diff(time1, time2) / 10.0 << "ms" + << std::endl; + out_file.close(); + + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + // kernel = 3, pad = 1, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 16, 24, 24, 16, 16, "depthwise_seperable"); + // kernel = 3, pad = 1, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 24, 24, 24, 24, 24, "MBConv_3x3_dw1"); + // kernel = 3, pad = 1, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 24, 24, 24, 24, 24, "MBConv_3x3_dw2"); + // kernel = 3, pad = 1, stride = 2 + paddle_mobile::TestDWConvAddBnReluOp( + 24, 24, 24, 24, 24, "MBConv_3x3_dw3"); + // kernel = 5, pad = 2, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 48, 12, 12, 48, 48, "MBConv_5x5_stage1_dw1"); + // kernel = 5, pad = 2, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 48, 12, 12, 48, 48, "MBConv_5x5_stage1_dw2"); + // kernel = 5, pad = 2, stride = 2 + paddle_mobile::TestDWConvAddBnReluOp( + 48, 12, 12, 48, 48, "MBConv_5x5_stage1_dw3"); + // kernel = 5, pad = 2, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 192, 6, 6, 192, 192, "MBConv_5x5_stage2_dw1"); + // kernel = 5, pad = 2, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 192, 6, 6, 192, 192, "MBConv_5x5_stage2_dw2"); + // kernel = 5, pad = 2, stride = 1 + paddle_mobile::TestDWConvAddBnReluOp( + 192, 6, 6, 192, 192, "MBConv_5x5_stage2_dw3"); + + return 0; +} diff --git a/test/operators/test_gru_op.cpp b/test/operators/test_gru_op.cpp index f2ce833661bfd1b3d751a7ac2d54cfb70114a6c6..b11ec4f5f77aca2c4997153863e70b1a6b209c32 100644 --- a/test/operators/test_gru_op.cpp +++ b/test/operators/test_gru_op.cpp @@ -12,18 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "../test_helper.h" #include "../test_include.h" #include "operators/gru_op.h" -int main() { - paddle_mobile::framework::Loader loader; - auto program = loader.Load(g_nlp); - PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, - "program file read fail"); +namespace paddle_mobile { - Executor4Test> - executor(program, "gru"); +template +int TestGruOp(int in_channels, int out_channels, std::string opname) { + int input_c = in_channels; + int output_c = out_channels; + paddle_mobile::framework::LoD lod{{0, input_c}}; + int batch_size = lod.size(); + framework::DDim input_shape = framework::make_ddim({input_c, output_c * 3}); + framework::DDim weight_shape = framework::make_ddim({output_c, output_c * 3}); + framework::DDim h0_shape = framework::make_ddim({batch_size, output_c}); + framework::DDim bias_shape = framework::make_ddim({batch_size, output_c * 3}); + VariableNameMap inputs; + VariableNameMap outputs; + auto scope = std::make_shared(); + inputs["Input"] = std::vector({"input"}); + inputs["Weight"] = std::vector({"weight"}); + inputs["H0"] = std::vector({"h0"}); + inputs["Bias"] = std::vector({"bias"}); + + outputs["BatchGate"] = std::vector({"output_batch_gate"}); + outputs["BatchResetHiddenPrev"] = + std::vector({"output_batch_reset_hidden_prev"}); + outputs["BatchHidden"] = std::vector({"output_batch_hidden"}); + outputs["Hidden"] = std::vector({"output_hidden"}); + + auto input_var = scope.get()->Var("input"); + auto input = input_var->template GetMutable(); + SetupTensor(input, input_shape, -127, 127); + input->set_lod(lod); + + auto weight_var = scope.get()->Var("weight"); + auto weight = weight_var->template GetMutable(); + SetupTensor(weight, weight_shape, -127, 127); + + auto h0_var = scope.get()->Var("h0"); + auto h0 = h0_var->template GetMutable(); + SetupTensor(h0, h0_shape, -127, 127); + + auto bias_var = scope.get()->Var("bias"); + auto bias = bias_var->template GetMutable(); + SetupTensor(bias, bias_shape, -127, 127); + + auto batch_gate_var = scope.get()->Var("output_batch_gate"); + auto batch_reset_hidden_prev_var = + scope.get()->Var("output_batch_reset_hidden_prev"); + auto batch_hidden_var = scope.get()->Var("output_batch_hidden"); + auto hidden_var = scope.get()->Var("output_hidden"); + + framework::AttributeMap attrs; + attrs["activation"].SetString(std::string("relu")); + attrs["gate_activation"].SetString(std::string("sigmoid")); + attrs["is_reverse"].Set(false); + + auto *op = + new operators::GruOp("gru", inputs, outputs, attrs, scope); + op->InferShape(); + op->Init(); + for (int i = 0; i < 10; ++i) { + op->Run(); + } + auto time1 = time(); + for (int i = 0; i < 10; ++i) { + op->Run(); + } + auto time2 = time(); + std::ofstream out_file("./out_gru.txt", std::ios::app); + out_file << opname << " cost :" << time_diff(time1, time2) / 10.0 << "ms" + << std::endl; + out_file.close(); + + delete op; + return 0; +} + +} // namespace paddle_mobile + +int main(int argc, char *argv[]) { + paddle_mobile::TestGruOp(384, 120, "gru_forward"); return 0; } diff --git a/test/operators/test_log_op.cpp b/test/operators/test_log_op.cpp index 8d675f06decc902365c32d797b432923933656f7..2f29e8711bb8de0e576a9a1485d96a448ec3d3c0 100644 --- a/test/operators/test_log_op.cpp +++ b/test/operators/test_log_op.cpp @@ -76,6 +76,5 @@ int main() { paddle_mobile::TestLogOp({1, 1, 2, 3}); paddle_mobile::TestLogOp({1, 3, 11, 22}); paddle_mobile::TestLogOp({1, 32, 112, 112}); - std::cout << "test log op pass." 
diff --git a/test/operators/test_log_op.cpp b/test/operators/test_log_op.cpp
index 8d675f06decc902365c32d797b432923933656f7..2f29e8711bb8de0e576a9a1485d96a448ec3d3c0 100644
--- a/test/operators/test_log_op.cpp
+++ b/test/operators/test_log_op.cpp
@@ -76,6 +76,5 @@ int main() {
   paddle_mobile::TestLogOp({1, 1, 2, 3});
   paddle_mobile::TestLogOp({1, 3, 11, 22});
   paddle_mobile::TestLogOp({1, 32, 112, 112});
-  std::cout << "test log op pass." << std::endl;
   return 0;
 }
diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp
index 50c0e7bd05da7f7a5ee1fd6912be0eff2f6e2958..f3b8fd151c83d115b003b226549ba351188808da 100644
--- a/test/operators/test_quantize_op.cpp
+++ b/test/operators/test_quantize_op.cpp
@@ -92,18 +92,10 @@ static float find_abs_max(const Tensor *input) {
   return max_abs;
 }

-int TestQuqntizeOp(int argc, char *argv[]) {
-  if (argc < 5) {
-    std::cout << "Usage: ./test-quantize-op batch_size channel height width"
-              << std::endl;
-    return 1;
-  }
-  int batch_size = atoi(argv[1]);
-  int channel = atoi(argv[2]);
-  int height = atoi(argv[3]);
-  int width = atoi(argv[4]);
-  std::cout << "batch_size: " << batch_size << ", channel: " << channel
-            << ", height: " << height << ", width: " << width << std::endl;
+int TestQuqntizeOp(const int batch_size, const int channel, const int height,
+                   const int width) {
+  DLOG << "batch_size: " << batch_size << ", channel: " << channel
+       << ", height: " << height << ", width: " << width;

   framework::DDim dim =
       framework::make_ddim({batch_size, channel, height, width});
@@ -140,9 +132,7 @@ int TestQuqntizeOp(int argc, char *argv[]) {
   framework::Tensor output_cmp;
   output_cmp.Resize(output->dims());
   float scale = 127 / output_scale_cmp;
-  // quantize(input, scale, &output_cmp);
-  // quantize(input, scale, &output_cmp);
-  quantize(input, scale, &output_cmp);
+  quantize(input, scale, &output_cmp);
   int8_t *output_cmp_data = output_cmp.data<int8_t>();
   for (int i = 0; i < output->numel(); ++i) {
     PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
@@ -157,5 +147,7 @@
 }  // namespace paddle_mobile

 int main(int argc, char *argv[]) {
-  return paddle_mobile::TestQuqntizeOp(argc, argv);
+  paddle_mobile::TestQuqntizeOp(1, 10, 10, 5);
+  paddle_mobile::TestQuqntizeOp(1, 111, 111, 5);
+  paddle_mobile::TestQuqntizeOp(5, 111, 111, 5);
 }
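
Note: the new hard-coded calls above need the `paddle_mobile::` qualifier (restored here) because main() sits outside the namespace that defines TestQuqntizeOp. The check itself pins the symmetric int8 scheme: the scale maps the largest observed magnitude onto the int8 range. In scalar form (sketch; the actual rounding mode is whatever quantize() uses, which this excerpt does not show):

    float max_abs = find_abs_max(input);  // helper from the test file
    float scale = 127 / max_abs;
    // per element x:
    int8_t q = static_cast<int8_t>(std::round(x * scale));
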
"running max case 3"; // case 3 { std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; @@ -109,14 +109,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{3, 4, 7, 8}; for (int i = 0; i < 4; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 4 - std::cerr << "running max case 4" << std::endl; + DLOG << "running max case 4"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; @@ -129,14 +129,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{6, 7, 8, 9, 10, 16, 17, 18, 19, 20}; for (int i = 0; i < 10; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 1 - std::cerr << "running sum case 1" << std::endl; + DLOG << "running sum case 1"; { std::vector data{1, 2, 3, 4}; input_x.Resize(framework::make_ddim({4, 1})); @@ -148,14 +148,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{3, 7}; for (int i = 0; i < 2; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 2 - std::cerr << "running sum case 2" << std::endl; + DLOG << "running sum case 2"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; input_x.Resize(framework::make_ddim({data.size(), 1})); @@ -167,14 +167,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{6, 49}; for (int i = 0; i < 2; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 3 - std::cerr << "running sum case 3" << std::endl; + DLOG << "running sum case 3"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; input_x.Resize(framework::make_ddim({4, 2})); @@ -186,14 +186,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{4, 6, 12, 14}; for (int i = 0; i < 4; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 4 - std::cerr << "running sum case 4" << std::endl; + DLOG << "running sum case 4"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; @@ -206,14 +206,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{7, 9, 11, 13, 15, 27, 29, 31, 33, 35}; for (int i = 0; i < 10; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 1 - std::cerr << "running first case 
1" << std::endl; + DLOG << "running first case 1"; { std::vector data{1, 2, 3, 4}; input_x.Resize(framework::make_ddim({4, 1})); @@ -225,14 +225,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{1, 3}; for (int i = 0; i < 2; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 2 - std::cerr << "running first case 2" << std::endl; + DLOG << "running first case 2"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; input_x.Resize(framework::make_ddim({data.size(), 1})); @@ -244,14 +244,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{1, 4}; for (int i = 0; i < 2; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 3 - std::cerr << "running first case 3" << std::endl; + DLOG << "running first case 3"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; input_x.Resize(framework::make_ddim({4, 2})); @@ -263,14 +263,14 @@ int main(int argc, char *argv[]) { std::vector expect_data{1, 2, 5, 6}; for (int i = 0; i < 4; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } } // case 4 - std::cerr << "running first case 4" << std::endl; + DLOG << "running first case 4"; { std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; @@ -283,8 +283,8 @@ int main(int argc, char *argv[]) { std::vector expect_data{1, 2, 3, 4, 5, 11, 12, 13, 14, 15}; for (int i = 0; i < 10; ++i) { if (output.data()[i] != expect_data[i]) { - std::cerr << "output[" << i << "]: " << output.data()[i] - << " != expect[" << i << "]: " << expect_data[i] << std::endl; + DLOG << "output[" << i << "]: " << output.data()[i] + << " != expect[" << i << "]: " << expect_data[i]; return 1; } } diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp index 40f6461a2cfdfb67b135a5a3a22c29bf19750189..260dd62781ad18b46e78db3cfaccf1fe27797175 100644 --- a/test/operators/test_sigmoid_op.cpp +++ b/test/operators/test_sigmoid_op.cpp @@ -76,6 +76,5 @@ int main() { paddle_mobile::TestSigmoidOp({1, 1, 2, 3}); paddle_mobile::TestSigmoidOp({1, 3, 11, 22}); paddle_mobile::TestSigmoidOp({1, 32, 112, 112}); - std::cout << "test sigmoid op pass." 
diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp
index 40f6461a2cfdfb67b135a5a3a22c29bf19750189..260dd62781ad18b46e78db3cfaccf1fe27797175 100644
--- a/test/operators/test_sigmoid_op.cpp
+++ b/test/operators/test_sigmoid_op.cpp
@@ -76,6 +76,5 @@ int main() {
   paddle_mobile::TestSigmoidOp({1, 1, 2, 3});
   paddle_mobile::TestSigmoidOp({1, 3, 11, 22});
   paddle_mobile::TestSigmoidOp({1, 32, 112, 112});
-  std::cout << "test sigmoid op pass." << std::endl;
   return 0;
 }
diff --git a/test/operators/test_tanh_op.cpp b/test/operators/test_tanh_op.cpp
index b8006931075d742724d18c3af3627f780a7bf454..d013b0eedfbe3bdc773e263aad594c89212ad6ce 100644
--- a/test/operators/test_tanh_op.cpp
+++ b/test/operators/test_tanh_op.cpp
@@ -58,7 +58,7 @@ int TestTanhOp(const std::vector<int> input_shape) {
   const float *output_data = output->data<float>();
   for (int i = 0; i < output->numel(); ++i) {
     float gap = output_data[i] - output_cmp_data[i];
-    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+    if (gap > 1e-5 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
       LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
                      << ", output_cmp_data[" << i
                      << "] = " << output_cmp_data[i];
diff --git a/tools/ci_build.sh b/tools/ci_build.sh
index f21bcdc67e6f73f61e0b33558672ac61fdf0fb22..d725afe4595b8e88578ec6c2f0f3c78bc0807a1b 100755
--- a/tools/ci_build.sh
+++ b/tools/ci_build.sh
@@ -15,6 +15,7 @@
 # limitations under the License.

 set -e
+source ./ci_run_test.sh

 function print_usage() {
     echo "\n${RED}Usage${NONE}:
@@ -231,6 +232,11 @@ function build_linux_fpga() {
         docker build -t paddle-mobile:dev - < Dockerfile
     fi
     docker run --rm -v `pwd`:/workspace paddle-mobile:dev bash /workspace/tools/docker_build_fpga.sh
+    cd -
+}
+
+function run_android_test() {
+    ExecuteAndroidTests $1
 }

 function main() {
@@ -239,9 +245,11 @@ function main() {
     case $CMD in
       android_armv7)
         build_android_armv7
+        run_android_test armeabi-v7a
         ;;
       android_armv8)
         build_android_armv8
+        run_android_test arm64-v8a
         ;;
       ios)
         build_ios
diff --git a/tools/ci_run_test.sh b/tools/ci_run_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6470a97b15a4497cf933ff0a22befa34383dd890
--- /dev/null
+++ b/tools/ci_run_test.sh
@@ -0,0 +1,43 @@
+#!/usr/bin/env bash
+
+operators=
+
+function AddTest() {
+  operators="${operators} $1"
+}
+
+function ExecuteAndroidTests() {
+  platform=$1
+  devices=`adb devices | grep -v devices | grep device | awk -F ' ' '{print $1}'`
+  for device in ${devices}; do
+    adb -s ${device} shell rm -rf /data/local/tmp/*
+    adb -s ${device} push ../build/${platform}/build/libpaddle-mobile.so /data/local/tmp/
+    for op in ${operators}; do
+      adb -s ${device} push ../test/build/test-${op}-op /data/local/tmp/
+      adb -s ${device} shell "cd /data/local/tmp/; LD_LIBRARY_PATH=. ./test-${op}-op"
+      echo "${BLUE}run test ${op} pass${NONE}"
+    done
+  done
+}
+
+AddTest batchnorm
+AddTest cast
+AddTest conv
+AddTest dequantize
+#AddTest elementwiseadd
+AddTest log
+AddTest logical-and
+AddTest logical-not
+AddTest logical-or
+AddTest logical-xor
+AddTest pool
+AddTest quantize
+AddTest relu
+AddTest relu6
+AddTest sequence-expand
+AddTest sequence-pool
+AddTest sequence-softmax
+AddTest sigmoid
+AddTest softmax
+AddTest tanh
+AddTest topk
diff --git a/tools/docker_build_fpga.sh b/tools/docker_build_fpga.sh
index 0927c328dd41b87f77adf19d514703e7bcafbce8..31a28b1532909079b70c1bb1ea63cede8d2c1668 100644
--- a/tools/docker_build_fpga.sh
+++ b/tools/docker_build_fpga.sh
@@ -1,3 +1,5 @@
+#!/usr/bin/env bash
+
 apt-get update
 apt-get install -y gcc g++ cmake