From 103f81de8fe51253bc90fe8a33ab5f1d6b039b65 Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 14 Sep 2017 15:04:01 +0800 Subject: [PATCH] Move test and benchmark of batch_norm and conv2d3x3r1 to ops directory. --- .../kernels/benchmark/batch_norm_benchmark.cc | 76 ----- mace/kernels/neon/conv_2d_neon_3x3.cc | 34 +-- mace/kernels/test/batch_norm_neon_test.cc | 73 ----- mace/ops/batch_norm_benchmark.cc | 68 +++++ mace/ops/batch_norm_test.cc | 40 ++- mace/ops/conv_2d_benchmark.cc | 1 + mace/ops/conv_2d_test.cc | 44 +++ mace/ops/ops_test_util.h | 286 +++++++++--------- 8 files changed, 320 insertions(+), 302 deletions(-) delete mode 100644 mace/kernels/benchmark/batch_norm_benchmark.cc delete mode 100644 mace/kernels/test/batch_norm_neon_test.cc create mode 100644 mace/ops/batch_norm_benchmark.cc diff --git a/mace/kernels/benchmark/batch_norm_benchmark.cc b/mace/kernels/benchmark/batch_norm_benchmark.cc deleted file mode 100644 index 33825587..00000000 --- a/mace/kernels/benchmark/batch_norm_benchmark.cc +++ /dev/null @@ -1,76 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/core/testing/test_benchmark.h" -#include "mace/kernels/batch_norm.h" - -namespace mace { -template -static void BatchNorm(int iters, int batch, int channels, int height, int width) { - - std::random_device rd; - std::mt19937 gen(rd()); - std::normal_distribution nd(0, 1); - - index_t input_size = batch * channels * height * width; - std::vector input(input_size, 0.0); - std::vector scale(channels, 0.0); - std::vector offset(channels, 0.0); - std::vector mean(channels, 0.0); - std::vector var(channels, 0.0); - - for (int i = 0; i < input_size; ++i) { - input[i] = nd(gen); - } - for (int i = 0; i < channels; ++i) { - scale[i] = nd(gen); - offset[i] = nd(gen); - mean[i] = nd(gen); - var[i] = std::abs(nd(gen)); - } - - // declare output - std::unique_ptr output(new T[input_size]); - auto functor = kernels::BatchNormFunctor(1e-5); - - while(iters--) { - functor(input.data(), - scale.data(), - offset.data(), - mean.data(), - var.data(), - batch, - channels, - height * width, - output.get()); - } -} - -#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \ - static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::ItemsProcessed(tot); \ - mace::testing::BytesProcessed(tot * (sizeof(TYPE)));\ - BatchNorm(iters, N, C, H, W); \ - } \ - BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) - -#define BM_BATCH_NORM(N, C, H, W, TYPE) \ - BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \ - BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); - -BM_BATCH_NORM(1, 1, 128, 128, float); -BM_BATCH_NORM(1, 1, 512, 512, float); -BM_BATCH_NORM(1, 1, 1024, 1024, float); -BM_BATCH_NORM(16, 1, 256, 256, float); -BM_BATCH_NORM(32, 1, 256, 256, float); -BM_BATCH_NORM(64, 1, 256, 256, float); -BM_BATCH_NORM(1, 3, 128, 128, float); -BM_BATCH_NORM(1, 3, 512, 512, float); -BM_BATCH_NORM(1, 3, 1024, 1024, float); -BM_BATCH_NORM(16, 3, 256, 256, float); -BM_BATCH_NORM(32, 3, 256, 256, float); -BM_BATCH_NORM(64, 3, 256, 256, float); -} // namespace mace \ No newline at end of file diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc index ee194f40..8ba5e82d 100644 --- a/mace/kernels/neon/conv_2d_neon_3x3.cc +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -8,7 +8,7 @@ namespace mace { namespace kernels { -static const int REGISTER_SIZE = 4; +static const int kRegisterSize = 4; void Conv2dNeonK3x3S1(const float* input, // NCHW const index_t* input_shape, @@ -44,7 +44,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW float32x4_t filter3 = vld1q_f32(filter_ptr+3); float32x4_t filter6 = vld1q_f32(filter_ptr+6); - const float* row[REGISTER_SIZE] = { + const float* row[kRegisterSize] = { input_ptr, input_ptr + input_width, input_ptr + 2 * input_width, input_ptr + 3 * input_width }; @@ -61,7 +61,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW float32x4_t sum0 = vdupq_n_f32(.0f); float32x4_t sum1 = vdupq_n_f32(.0f); float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 - float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567 + float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567 float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 @@ -70,7 +70,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 - float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567 + float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567 float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 @@ -79,7 +79,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); row0_ext_0 = vld1q_f32(row[2]); //0123 - row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567 + row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567 row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 @@ -97,7 +97,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2); row1_ext_0 = vld1q_f32(row[3]); //0123 - row1_latter = vld1q_f32(row[3] + REGISTER_SIZE); //4567 + row1_latter = vld1q_f32(row[3] + kRegisterSize); //4567 row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 @@ -112,10 +112,10 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW vst1q_f32(output_ptr1, output_row0); vst1q_f32(output_ptr2, output_row1); - output_ptr1 += REGISTER_SIZE; - output_ptr2 += REGISTER_SIZE; - for(int i = 0; i < REGISTER_SIZE; ++i) { - row[i] += REGISTER_SIZE; + output_ptr1 += kRegisterSize; + output_ptr2 += kRegisterSize; + for(int i = 0; i < kRegisterSize; ++i) { + row[i] += kRegisterSize; } } for (; remain_count > 0; --remain_count) { @@ -138,13 +138,13 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ++output_ptr1; ++output_ptr2; - for(int i = 0; i < REGISTER_SIZE; ++i) { + for(int i = 0; i < kRegisterSize; ++i) { row[i] += 1; } } output_ptr1 += width; output_ptr2 += width; - for(int i = 0; i < REGISTER_SIZE; ++i) { + for(int i = 0; i < kRegisterSize; ++i) { row[i] += 2 + input_width; } } @@ -155,7 +155,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW for(; count > 0; --count) { float32x4_t sum0 = vdupq_n_f32(.0f); float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 - float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567 + float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567 float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 @@ -164,7 +164,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 - float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567 + float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567 float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 @@ -173,7 +173,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); row0_ext_0 = vld1q_f32(row[2]); //0123 - row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567 + row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567 row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 @@ -184,9 +184,9 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW float32x4_t output_row0 = vld1q_f32(output_ptr1); output_row0 = vaddq_f32(output_row0, sum0); vst1q_f32(output_ptr1, output_row0); - output_ptr1 += REGISTER_SIZE; + output_ptr1 += kRegisterSize; for(int i = 0; i < 3; ++i) { - row[i] += REGISTER_SIZE; + row[i] += kRegisterSize; } } for (; remain_count > 0; --remain_count) { diff --git a/mace/kernels/test/batch_norm_neon_test.cc b/mace/kernels/test/batch_norm_neon_test.cc deleted file mode 100644 index b47a3360..00000000 --- a/mace/kernels/test/batch_norm_neon_test.cc +++ /dev/null @@ -1,73 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include -#include "gtest/gtest.h" -#include "mace/kernels/batch_norm.h" - -namespace mace { - -TEST(BatchNormNeonTest, Simple) { - std::random_device rd; - std::mt19937 gen(rd()); - std::normal_distribution nd(0, 1); - srand(time(NULL)); - - // generate random input - index_t batch = 1 + rand() % 128; - index_t channels = 3; - index_t height = 10 + rand() % 100; - index_t width = 10 + rand() % 100; - - index_t input_size = batch * channels * height * width; - std::vector input(input_size, 0.0); - std::vector scale(channels, 0.0); - std::vector offset(channels, 0.0); - std::vector mean(channels, 0.0); - std::vector var(channels, 0.0); - - for (int i = 0; i < input_size; ++i) { - input[i] = nd(gen); - } - for (int i = 0; i < channels; ++i) { - scale[i] = nd(gen); - offset[i] = nd(gen); - mean[i] = nd(gen); - var[i] = std::abs(nd(gen)); - } - - // declare output - std::unique_ptr output(new float[input_size]); - std::unique_ptr output_neon(new float[input_size]); - - kernels::BatchNormFunctor(1e-5)( - input.data(), - scale.data(), - offset.data(), - mean.data(), - var.data(), - batch, - channels, - height * width, - output.get() - ); - kernels::BatchNormFunctor(1e-5)( - input.data(), - scale.data(), - offset.data(), - mean.data(), - var.data(), - batch, - channels, - height * width, - output_neon.get() - ); - - for (index_t i = 0; i < input_size; ++i) { - EXPECT_FLOAT_EQ(output[i], output_neon[i]); - } - -} - -} // namespace mace \ No newline at end of file diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc new file mode 100644 index 00000000..789934fb --- /dev/null +++ b/mace/ops/batch_norm_benchmark.cc @@ -0,0 +1,68 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +template +static void BatchNorm(int iters, int batch, int channels, int height, int width) { + + mace::testing::StopTiming(); + + OpsTestNet net; + OpDefBuilder("BatchNorm", "BatchNormBM") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + + mace::testing::StartTiming(); + while(iters--) { + net.RunOp(D); + } +} + +#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \ + static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ + BatchNorm(iters, N, C, H, W); \ + } \ + BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + +#define BM_BATCH_NORM(N, C, H, W, TYPE) \ + BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \ + BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); + +BM_BATCH_NORM(1, 1, 512, 512, float); +BM_BATCH_NORM(1, 1, 1024, 1024, float); +BM_BATCH_NORM(1, 3, 128, 128, float); +BM_BATCH_NORM(1, 3, 512, 512, float); +BM_BATCH_NORM(1, 3, 1024, 1024, float); +BM_BATCH_NORM(1, 64, 256, 256, float); +BM_BATCH_NORM(1, 64, 512, 512, float); +BM_BATCH_NORM(1, 128, 256, 256, float); +BM_BATCH_NORM(1, 128, 512, 512, float); +BM_BATCH_NORM(32, 1, 256, 256, float); +BM_BATCH_NORM(32, 3, 256, 256, float); +} // namespace mace \ No newline at end of file diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index ca32ec75..7d74aac4 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -9,7 +9,7 @@ namespace mace { class BatchNormOpTest : public OpsTestBase {}; -TEST_F(BatchNormOpTest, Simple) { +TEST_F(BatchNormOpTest, SimpleCPU) { // Construct graph auto& net = test_net(); OpDefBuilder("BatchNorm", "BatchNormTest") @@ -40,4 +40,42 @@ TEST_F(BatchNormOpTest, Simple) { ExpectTensorNear(expected, *net.GetOutput("Output"), 0.01); } +TEST_F(BatchNormOpTest, SimpleNeon) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 10 + rand() % 50; + index_t width = 10 + rand() % 50; + // Construct graph + auto& net = test_net(); + OpDefBuilder("BatchNorm", "BatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + + // run cpu + net.RunOp(); + + // Check + Tensor expected = *net.GetOutput("Output"); + + // Run NEON + net.RunOp(DeviceType::NEON); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-5); +} + } diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 1fe070e2..772dd200 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -62,5 +62,6 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); } // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 43228e6a..c516dd73 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -192,3 +192,47 @@ TEST_F(Conv2dOpTest, Conv1x1) { } // TODO we need more tests +TEST_F(Conv2dOpTest, Conv3x3R1) { + auto func = [&](Padding type) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 5; + index_t input_channels = 3 + rand() % 50; + index_t height = 10 + rand() % 100; + index_t width = 10 + rand() % 100; + index_t output_channels = 3 + rand() % 50; + // Construct graph + auto& net = test_net(); + OpDefBuilder("Conv2d", "Conv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntsArg("strides", {1, 1}); + net.AddIntArg("padding", type); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput("Filter", {output_channels, input_channels, 3, 3}); + net.AddRandomInput("Bias", {output_channels}); + // run cpu + net.RunOp(); + + // Check + Tensor expected = *net.GetOutput("Output"); + + // Run NEON + net.RunOp(DeviceType::NEON); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-5); + + }; + + func(VALID); + func(SAME); +} diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 991eee57..8d218058 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -15,152 +15,167 @@ namespace mace { class OpDefBuilder { - public: - OpDefBuilder(const char* type, const char* name) { - op_def_.set_type(type); - op_def_.set_name(name); - } - OpDefBuilder& Input(const char* input_name) { - op_def_.add_input(input_name); - return *this; - } - OpDefBuilder& Output(const char* output_name) { - op_def_.add_output(output_name); - return *this; - } - void Finalize(OperatorDef* op_def) const { - MACE_CHECK(op_def != nullptr, "input should not be null."); - *op_def = op_def_; - } - OperatorDef op_def_; + public: + OpDefBuilder(const char *type, const char *name) { + op_def_.set_type(type); + op_def_.set_name(name); + } + + OpDefBuilder &Input(const char *input_name) { + op_def_.add_input(input_name); + return *this; + } + + OpDefBuilder &Output(const char *output_name) { + op_def_.add_output(output_name); + return *this; + } + + void Finalize(OperatorDef *op_def) const { + MACE_CHECK(op_def != nullptr, "input should not be null."); + *op_def = op_def_; + } + + OperatorDef op_def_; }; class OpsTestNet { - public: - OpsTestNet() {} - - template - void AddInputFromArray(const char* name, - const std::vector& shape, - const std::vector& data) { - Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); - input->Resize(shape); - T* input_data = input->mutable_data(); - MACE_CHECK(input->size() == data.size()); - memcpy(input_data, data.data(), data.size() * sizeof(T)); - } + public: + OpsTestNet() {} + + template + void AddInputFromArray(const char *name, + const std::vector &shape, + const std::vector &data) { + Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); + input->Resize(shape); + T *input_data = input->mutable_data(); + MACE_CHECK(input->size() == data.size()); + memcpy(input_data, data.data(), data.size() * sizeof(T)); + } - template - void AddRandomInput(const char* name, const std::vector& shape) { - Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); - input->Resize(shape); - float* input_data = input->mutable_data(); + template + void AddRepeatedInput(const char *name, + const std::vector &shape, + const T data) { + Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); + input->Resize(shape); + T *input_data = input->mutable_data(); + MACE_CHECK(input->size() == data.size()); + std::fill(input_data, input_data + input->size(), data); + } - std::random_device rd; - std::mt19937 gen(rd()); - std::normal_distribution nd(0, 1); + template + void AddRandomInput(const char *name, const std::vector &shape, bool positive = false) { + Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); + input->Resize(shape); + float *input_data = input->mutable_data(); - std::generate(input_data, input_data + input->size(), - [&gen, &nd]{ return nd(gen); }); - } + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution nd(0, 1); - void AddIntArg(const char* name, const int value) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - arg->set_i(value); - } + std::generate(input_data, input_data + input->size(), + [&gen, &nd, positive] { return positive ? std::abs(nd(gen)) : nd(gen); }); + } - void AddFloatArg(const char* name, const float value) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - arg->set_f(value); - } + void AddIntArg(const char *name, const int value) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + arg->set_i(value); + } - void AddStringArg(const char* name, const char* value) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - arg->set_s(value); - } + void AddFloatArg(const char *name, const float value) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + arg->set_f(value); + } - void AddIntsArg(const char* name, const std::vector& values) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - for (auto value : values) { - arg->add_ints(value); - } + void AddStringArg(const char *name, const char *value) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + arg->set_s(value); + } + + void AddIntsArg(const char *name, const std::vector &values) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + for (auto value : values) { + arg->add_ints(value); } + } - void AddFloatsArg(const char* name, const std::vector& values) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - for (auto value : values) { - arg->add_floats(value); - } + void AddFloatsArg(const char *name, const std::vector &values) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + for (auto value : values) { + arg->add_floats(value); } + } - void AddStringsArg(const char* name, const std::vector& values) { - auto arg = op_def_.add_arg(); - arg->set_name(name); - for (auto value : values) { - arg->add_strings(value); - } + void AddStringsArg(const char *name, const std::vector &values) { + auto arg = op_def_.add_arg(); + arg->set_name(name); + for (auto value : values) { + arg->add_strings(value); } + } - OperatorDef* operator_def() { return &op_def_; } + OperatorDef *operator_def() { return &op_def_; } - Workspace* ws() { return &ws_; } + Workspace *ws() { return &ws_; } - bool RunOp(DeviceType device) { - if (!net_) { - NetDef net_def; - net_def.add_op()->CopyFrom(op_def_); - VLOG(3) << net_def.DebugString(); - net_ = CreateNet(net_def, &ws_, device); - } - return net_->Run(); + bool RunOp(DeviceType device) { + if (!net_) { + NetDef net_def; + net_def.add_op()->CopyFrom(op_def_); + VLOG(3) << net_def.DebugString(); + net_ = CreateNet(net_def, &ws_, device); } + return net_->Run(); + } - bool RunOp() { - return RunOp(DeviceType::CPU); - } + bool RunOp() { + return RunOp(DeviceType::CPU); + } - Tensor* GetOutput(const char* output_name) { - return ws_.GetTensor(output_name); - } + Tensor *GetOutput(const char *output_name) { + return ws_.GetTensor(output_name); + } - public: - Workspace ws_; - OperatorDef op_def_; - std::unique_ptr net_; + public: + Workspace ws_; + OperatorDef op_def_; + std::unique_ptr net_; }; class OpsTestBase : public ::testing::Test { - public: - OpsTestNet& test_net() { return test_net_; }; - - protected: - virtual void TearDown() { - auto ws = test_net_.ws(); - auto tensor_names = ws->Tensors(); - for (auto& name : tensor_names) { - ws->RemoveTensor(name); - } + public: + OpsTestNet &test_net() { return test_net_; }; + + protected: + virtual void TearDown() { + auto ws = test_net_.ws(); + auto tensor_names = ws->Tensors(); + for (auto &name : tensor_names) { + ws->RemoveTensor(name); } + } - private: - OpsTestNet test_net_; + private: + OpsTestNet test_net_; }; -template -Tensor CreateTensor(const std::vector& shape, const std::vector& data) { +template +Tensor CreateTensor(const std::vector &shape, const std::vector &data) { Tensor res(cpu_allocator(), DataTypeToEnum::v()); res.Resize(shape); - float* input_data = res.mutable_data(); + float *input_data = res.mutable_data(); memcpy(input_data, data.data(), data.size() * sizeof(T)); return res; } -inline bool IsSameSize(const Tensor& x, const Tensor& y) { +inline bool IsSameSize(const Tensor &x, const Tensor &y) { if (x.dim_size() != y.dim_size()) return false; for (int d = 0; d < x.dim_size(); ++d) { if (x.dim(d) != y.dim(d)) return false; @@ -168,58 +183,59 @@ inline bool IsSameSize(const Tensor& x, const Tensor& y) { return true; } -inline std::string ShapeToString(const Tensor& x) { +inline std::string ShapeToString(const Tensor &x) { std::stringstream stream; for (int i = 0; i < x.dim_size(); i++) { - if (i > 0) stream<<","; + if (i > 0) stream << ","; int64_t dim = x.dim(i); if (dim < 0) { - stream<<"?"; + stream << "?"; } else { - stream< +template struct is_floating_point_type { static const bool value = std::is_same::value || std::is_same::value; }; -template -inline void ExpectEqual(const T& a, const T& b) { +template +inline void ExpectEqual(const T &a, const T &b) { EXPECT_EQ(a, b); } -template <> -inline void ExpectEqual(const float& a, const float& b) { +template<> +inline void ExpectEqual(const float &a, const float &b) { EXPECT_FLOAT_EQ(a, b); } -template <> -inline void ExpectEqual(const double& a, const double& b) { +template<> +inline void ExpectEqual(const double &a, const double &b) { EXPECT_DOUBLE_EQ(a, b); } -inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { +inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) { ASSERT_EQ(x.dtype(), y.dtype()); ASSERT_TRUE(IsSameSize(x, y)) - << "x.shape [" << ShapeToString(x) << "] vs " - << "y.shape [ " << ShapeToString(y) << "]"; + << "x.shape [" << ShapeToString(x) << "] vs " + << "y.shape [ " << ShapeToString(y) << "]"; } -template ::value> +template::value> struct Expector; + // Partial specialization for float and double. -template +template struct Expector { - static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + static void Equal(const T &a, const T &b) { ExpectEqual(a, b); } - static void Equal(const Tensor& x, const Tensor& y) { + static void Equal(const Tensor &x, const Tensor &y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); auto a = x.data(); @@ -229,22 +245,22 @@ struct Expector { } } - static void Near(const Tensor& x, const Tensor& y, const double abs_err) { + static void Near(const Tensor &x, const Tensor &y, const double abs_err) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); auto a = x.data(); auto b = y.data(); for (int i = 0; i < x.size(); ++i) { EXPECT_NEAR(a[i], b[i], abs_err) - << "a = " << a << " b = " << b << " index = " << i; + << "a = " << a << " b = " << b << " index = " << i; } } }; -template -void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { +template +void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) { static_assert(is_floating_point_type::value, "T is not a floating point type"); - Expector::Near(x, y ,abs_err); + Expector::Near(x, y, abs_err); } } // namespace mace -- GitLab