From 7fbe73616a09d0283438a4843c0ad2e480ee47ba Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Wed, 13 Sep 2017 21:06:01 +0800 Subject: [PATCH] Refactor ops test util --- mace/kernels/conv_pool_2d_util.cc | 22 +++++---- mace/kernels/conv_pool_2d_util.h | 4 +- mace/ops/batch_norm_test.cc | 17 +++---- mace/ops/conv_2d.h | 8 ++-- mace/ops/conv_2d_benchmark.cc | 7 ++- mace/ops/conv_2d_test.cc | 76 ++++++++++++++++--------------- mace/ops/ops_test_util.h | 47 +++++++++++++------ mace/ops/pooling.h | 8 ++-- mace/ops/pooling_test.cc | 76 ++++++++++++++++--------------- mace/ops/resize_bilinear_test.cc | 24 +++++----- 10 files changed, 162 insertions(+), 127 deletions(-) diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 7e9e272f..fe6f4e68 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -12,20 +12,23 @@ void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW const int* dilations, const int* strides, Padding padding, - std::vector* output_shape, - std::vector* padding_size) { + index_t* output_shape, + int* padding_size) { MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); + MACE_CHECK_NOTNULL(output_shape); + MACE_CHECK_NOTNULL(padding_size); /* * Convlution/pooling arithmetic: * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 * For details, see https://arxiv.org/pdf/1603.07285.pdf or * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html */ - *padding_size = {0, 0}; + padding_size[0] = 0; + padding_size[1] = 0; index_t output_height, output_width; index_t kernel_height = filter_shape[2]; @@ -57,16 +60,15 @@ void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW // utilize the more centered features. We need to benchmark // based on the model accuracy. - (*padding_size)[0] = (output_height - 1) * strides[0] + + padding_size[0] = (output_height - 1) * strides[0] + k_extent_height - input_shape[2]; - (*padding_size)[1] = (output_width - 1) * strides[1] + + padding_size[1] = (output_width - 1) * strides[1] + k_extent_width - input_shape[3]; - *output_shape = std::vector(4); // NCHW - (*output_shape)[0] = input_shape[0]; - (*output_shape)[1] = output_channels; - (*output_shape)[2] = output_height; - (*output_shape)[3] = output_width; + output_shape[0] = input_shape[0]; + output_shape[1] = output_channels; + output_shape[2] = output_height; + output_shape[3] = output_width; } } // namespace kernels diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index c1c5154c..3cca8a79 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -22,8 +22,8 @@ void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW const int* dilations, const int* strides, Padding padding, - std::vector* output_shape, - std::vector* padding_size); + index_t* output_shape, + int* padding_size); } // namespace kernels } // namespace mace diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index ef89fcee..21ee8c56 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -11,6 +11,7 @@ class BatchNormOpTest : public OpsTestBase {}; TEST_F(BatchNormOpTest, Simple) { // Construct graph + auto net = test_net(); OpDefBuilder("BatchNorm", "BatchNormTest") .Input("Input") .Input("Scale") @@ -18,25 +19,25 @@ TEST_F(BatchNormOpTest, Simple) { .Input("Mean") .Input("Var") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add input data - AddInputFromArray("Input", {1, 1, 6, 2}, + net->AddInputFromArray("Input", {1, 1, 6, 2}, {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); - AddInputFromArray("Scale", {1}, {4.0f}); - AddInputFromArray("Offset", {1}, {2.0}); - AddInputFromArray("Mean", {1}, {10}); - AddInputFromArray("Var", {1}, {11.67f}); + net->AddInputFromArray("Scale", {1}, {4.0f}); + net->AddInputFromArray("Offset", {1}, {2.0}); + net->AddInputFromArray("Mean", {1}, {10}); + net->AddInputFromArray("Var", {1}, {11.67f}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 1, 6, 2}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83, 3.17, 3.17, 5.51, 5.51, 7.86, 7.86}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.01); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.01); } } diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 14b9c8ee..6ae1e06f 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -25,15 +25,15 @@ class Conv2dOp : public ConvPool2dOpBase { const Tensor* bias = this->Input(BIAS); Tensor* output = this->Output(OUTPUT); - std::vector output_shape; - std::vector paddings; + std::vector output_shape(4); + std::vector paddings(2); kernels::CalcPaddingAndOutputSize(input->shape().data(), filter->shape().data(), this->dilations_.data(), this->strides_.data(), this->padding_, - &output_shape, - &paddings); + output_shape.data(), + paddings.data()); output->Resize(output_shape); auto conv2d = kernels::Conv2dFunctor(this->strides_.data(), diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 6c347c27..49c80b67 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -2,10 +2,14 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // +#include + #include "mace/core/testing/test_benchmark.h" -#include "mace/ops/conv_2d.h" +#include "mace/kernels/conv_2d.h" +#include "mace/kernels/conv_pool_2d_util.h" namespace mace { +namespace kernels { template static void Conv2d(int iters, int batch, int channels, int height, int width, @@ -34,4 +38,5 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); +} // namespace kernels } // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 797075f2..1aec07f8 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -12,71 +12,73 @@ class Conv2dOpTest : public OpsTestBase {}; TEST_F(Conv2dOpTest, Simple_VALID) { // Construct graph + auto net = test_net(); OpDefBuilder("Conv2d", "Conv2dTest") .Input("Input") .Input("Filter") .Input("Bias") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("strides", {1, 1}); - AddIntArg("padding", Padding::VALID); - AddIntsArg("dilations", {1, 1}); + net->AddIntsArg("strides", {1, 1}); + net->AddIntArg("padding", Padding::VALID); + net->AddIntsArg("dilations", {1, 1}); // Add input data - AddInputFromArray("Input", {1, 2, 3, 3}, + net->AddInputFromArray("Input", {1, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - AddInputFromArray("Filter", {1, 2, 3, 3}, + net->AddInputFromArray("Filter", {1, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - AddInputFromArray("Bias", {1}, {0.1f}); + net->AddInputFromArray("Bias", {1}, {0.1f}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 1, 1, 1}, {18.1f}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(Conv2dOpTest, Simple_SAME) { // Construct graph + auto net = test_net(); OpDefBuilder("Conv2d", "Conv2dTest") .Input("Input") .Input("Filter") .Input("Bias") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("strides", {1, 1}); - AddIntArg("padding", Padding::SAME); - AddIntsArg("dilations", {1, 1}); + net->AddIntsArg("strides", {1, 1}); + net->AddIntArg("padding", Padding::SAME); + net->AddIntsArg("dilations", {1, 1}); // Add input data - AddInputFromArray("Input", {1, 2, 3, 3}, + net->AddInputFromArray("Input", {1, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - AddInputFromArray("Filter", {1, 2, 3, 3}, + net->AddInputFromArray("Filter", {1, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - AddInputFromArray("Bias", {1}, {0.1f}); + net->AddInputFromArray("Bias", {1}, {0.1f}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 1, 3, 3}, @@ -84,25 +86,26 @@ TEST_F(Conv2dOpTest, Simple_SAME) { 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(Conv2dOpTest, Combined) { // Construct graph + auto net = test_net(); OpDefBuilder("Conv2d", "Conv2dTest") .Input("Input") .Input("Filter") .Input("Bias") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("strides", {2, 2}); - AddIntArg("padding", Padding::SAME); - AddIntsArg("dilations", {1, 1}); + net->AddIntsArg("strides", {2, 2}); + net->AddIntArg("padding", Padding::SAME); + net->AddIntsArg("dilations", {1, 1}); // Add input data - AddInputFromArray("Input", {1, 2, 5, 5}, + net->AddInputFromArray("Input", {1, 2, 5, 5}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -113,15 +116,15 @@ TEST_F(Conv2dOpTest, Combined) { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - AddInputFromArray("Filter", {2, 2, 3, 3}, + net->AddInputFromArray("Filter", {2, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); - AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); + net->AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 2, 3, 3}, @@ -133,25 +136,26 @@ TEST_F(Conv2dOpTest, Combined) { 4.2f, 6.2f, 4.2f}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(Conv2dOpTest, Conv1x1) { // Construct graph + auto net = test_net(); OpDefBuilder("Conv2d", "Conv2dTest") .Input("Input") .Input("Filter") .Input("Bias") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("strides", {1, 1}); - AddIntArg("padding", Padding::VALID); - AddIntsArg("dilations", {1, 1}); + net->AddIntsArg("strides", {1, 1}); + net->AddIntArg("padding", Padding::VALID); + net->AddIntsArg("dilations", {1, 1}); // Add input data - AddInputFromArray("Input", {1, 5, 3, 10}, + net->AddInputFromArray("Input", {1, 5, 3, 10}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -167,13 +171,13 @@ TEST_F(Conv2dOpTest, Conv1x1) { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - AddInputFromArray("Filter", {2, 5, 1, 1}, + net->AddInputFromArray("Filter", {2, 5, 1, 1}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}); - AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); + net->AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); // Run - RunOp(DeviceType::NEON); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 2, 3, 10}, @@ -184,7 +188,7 @@ TEST_F(Conv2dOpTest, Conv1x1) { 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } // TODO we need more tests diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 9dd44c25..14ec485c 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -35,17 +35,12 @@ class OpDefBuilder { OperatorDef op_def_; }; -class OpsTestBase : public ::testing::Test { - protected: - virtual void TearDown() { - auto tensor_names = ws_.Tensors(); - for (auto& name : tensor_names) { - ws_.RemoveTensor(name); - } - } +class OpsTestNet { public: template - void AddInputFromArray(const char* name, const std::vector& shape, const std::vector& data) { + void AddInputFromArray(const char* name, + const std::vector& shape, + const std::vector& data) { Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum::v()); input->Resize(shape); float* input_data = input->mutable_data(); @@ -97,12 +92,16 @@ class OpsTestBase : public ::testing::Test { OperatorDef* operator_def() { return &op_def_; } + Workspace* ws() { return &ws_; } + bool RunOp(DeviceType device) { - NetDef net_def; - net_def.add_op()->CopyFrom(op_def_); - VLOG(0) << net_def.DebugString(); - auto net = CreateNet(net_def, &ws_, device); - return net->Run(); + if (!net_) { + NetDef net_def; + net_def.add_op()->CopyFrom(op_def_); + VLOG(0) << net_def.DebugString(); + net_ = CreateNet(net_def, &ws_, device); + } + return net_->Run(); } bool RunOp() { @@ -113,9 +112,27 @@ class OpsTestBase : public ::testing::Test { return ws_.GetTensor(output_name); } - private: + public: Workspace ws_; OperatorDef op_def_; + std::unique_ptr net_; +}; + +class OpsTestBase : public ::testing::Test { + public: + OpsTestNet* test_net() { return &test_net_; }; + + protected: + virtual void TearDown() { + auto ws = test_net_.ws(); + auto tensor_names = ws->Tensors(); + for (auto& name : tensor_names) { + ws->RemoveTensor(name); + } + } + + private: + OpsTestNet test_net_; }; template diff --git a/mace/ops/pooling.h b/mace/ops/pooling.h index 0c36b546..4d0001df 100644 --- a/mace/ops/pooling.h +++ b/mace/ops/pooling.h @@ -26,8 +26,8 @@ public: Tensor* output = this->Output(OUTPUT); std::vector in_shape = input->shape(); - std::vector output_shape; - std::vector paddings; + std::vector output_shape(4); + std::vector paddings(2); std::vector filter_shape = std::vector(4); filter_shape[0] = in_shape[1]; filter_shape[1] = in_shape[0]; @@ -38,8 +38,8 @@ public: this->dilations_.data(), this->strides_.data(), this->padding_, - &output_shape, - &paddings); + output_shape.data(), + paddings.data()); output->Resize(output_shape); auto pooling_func = kernels::PoolingFunctor(pooling_type_, diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index f56bff61..03831c7d 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -15,20 +15,21 @@ class PoolingOpTest : public OpsTestBase {}; TEST_F(PoolingOpTest, MAX_VALID) { // Construct graph + auto net = test_net(); OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("kernels", {2, 2}); - AddIntsArg("strides", {2, 2}); - AddIntArg("padding", Padding::VALID); - AddIntsArg("dilations", {1, 1}); - AddIntArg("pooling_type", PoolingType::MAX); + net->AddIntsArg("kernels", {2, 2}); + net->AddIntsArg("strides", {2, 2}); + net->AddIntArg("padding", Padding::VALID); + net->AddIntsArg("dilations", {1, 1}); + net->AddIntArg("pooling_type", PoolingType::MAX); // Add input data - AddInputFromArray("Input", {1, 2, 4, 4}, + net->AddInputFromArray("Input", {1, 2, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -39,32 +40,33 @@ TEST_F(PoolingOpTest, MAX_VALID) { 28, 29, 30, 31}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 2, 2, 2}, {5, 7, 13, 15, 21, 23, 29, 31}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(PoolingOpTest, AVG_VALID) { // Construct graph + auto net = test_net(); OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("kernels", {2, 2}); - AddIntsArg("strides", {2, 2}); - AddIntArg("padding", Padding::VALID); - AddIntsArg("dilations", {1, 1}); - AddIntArg("pooling_type", PoolingType::AVG); + net->AddIntsArg("kernels", {2, 2}); + net->AddIntsArg("strides", {2, 2}); + net->AddIntArg("padding", Padding::VALID); + net->AddIntsArg("dilations", {1, 1}); + net->AddIntArg("pooling_type", PoolingType::AVG); // Add input data - AddInputFromArray("Input", {1, 2, 4, 4}, + net->AddInputFromArray("Input", {1, 2, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -75,72 +77,74 @@ TEST_F(PoolingOpTest, AVG_VALID) { 28, 29, 30, 31}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 2, 2, 2}, {2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(PoolingOpTest, MAX_SAME) { // Construct graph + auto net = test_net(); OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("kernels", {2, 2}); - AddIntsArg("strides", {2, 2}); - AddIntArg("padding", Padding::SAME); - AddIntsArg("dilations", {1, 1}); - AddIntArg("pooling_type", PoolingType::MAX); + net->AddIntsArg("kernels", {2, 2}); + net->AddIntsArg("strides", {2, 2}); + net->AddIntArg("padding", Padding::SAME); + net->AddIntsArg("dilations", {1, 1}); + net->AddIntArg("pooling_type", PoolingType::MAX); // Add input data - AddInputFromArray("Input", {1, 1, 3, 3}, + net->AddInputFromArray("Input", {1, 1, 3, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 1, 2, 2}, {4, 5, 7, 8}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(PoolingOpTest, MAX_VALID_DILATION) { // Construct graph + auto net = test_net(); OpDefBuilder("Pooling", "PoolingTest") .Input("Input") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add args - AddIntsArg("kernels", {2, 2}); - AddIntsArg("strides", {1, 1}); - AddIntArg("padding", Padding::VALID); - AddIntsArg("dilations", {2, 2}); - AddIntArg("pooling_type", PoolingType::MAX); + net->AddIntsArg("kernels", {2, 2}); + net->AddIntsArg("strides", {1, 1}); + net->AddIntArg("padding", Padding::VALID); + net->AddIntsArg("dilations", {2, 2}); + net->AddIntArg("pooling_type", PoolingType::MAX); // Add input data - AddInputFromArray("Input", {1, 1, 4, 4}, + net->AddInputFromArray("Input", {1, 1, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 1, 2, 2}, {10, 11, 14, 15}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index a2567699..d6ee33b6 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -13,49 +13,51 @@ class ResizeBilinearTest : public OpsTestBase {}; TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) { testing::internal::LogToStderr(); // Construct graph + auto net = test_net(); OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") .Input("Input") .Input("OutSize") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); // Add input data vector input(24); std::iota(begin(input), end(input), 0); - AddInputFromArray("Input", {1, 3, 2, 4}, input); - AddInputFromArray("OutSize", {2}, {1, 2}); + net->AddInputFromArray("Input", {1, 3, 2, 4}, input); + net->AddInputFromArray("OutSize", {2}, {1, 2}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 3, 1, 2}, {0, 2, 8, 10, 16, 18}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { testing::internal::LogToStderr(); // Construct graph + auto net = test_net(); OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") .Input("Input") .Input("OutSize") .Output("Output") - .Finalize(operator_def()); + .Finalize(net->operator_def()); - AddIntArg("align_corners", 1); + net->AddIntArg("align_corners", 1); // Add input data vector input(24); std::iota(begin(input), end(input), 0); - AddInputFromArray("Input", {1, 3, 2, 4}, input); - AddInputFromArray("OutSize", {2}, {1, 2}); + net->AddInputFromArray("Input", {1, 3, 2, 4}, input); + net->AddInputFromArray("OutSize", {2}, {1, 2}); // Run - RunOp(); + net->RunOp(); // Check Tensor expected = CreateTensor({1, 3, 1, 2}, {0, 3, 8, 11, 16, 19}); - ExpectTensorNear(expected, *GetOutput("Output"), 0.001); + ExpectTensorNear(expected, *net->GetOutput("Output"), 0.001); } -- GitLab