// Copyright 2018 The MACE Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "gmock/gmock.h" #include "mace/ops/ops_test_util.h" namespace mace { namespace ops { namespace test { class ExtractPoolingTest : public OpsTestBase {}; namespace { template void TestExtractPooling(const std::vector &input_shape, const std::vector &input_value, const int modulus, const int num_log_count, const int include_variance, const std::vector &input_time_range, const std::vector &input_indexes, const std::vector &forward_indexes, const std::vector &counts, const std::vector &output_indexes, const std::vector &output_time_range, const std::vector &output_shape, const std::vector &output_value) { // Construct graph OpsTestNet net; net.AddInputFromArray("Input", input_shape, input_value); OpDefBuilder("ExtractPooling", "ExtractPoolingTest") .Input("Input") .AddIntArg("modulus", modulus) .AddIntArg("include_variance", include_variance) .AddIntArg("num_log_count", num_log_count) .AddIntsArg("input_indexes", input_indexes) .AddIntsArg("output_indexes", output_indexes) .AddIntsArg("forward_indexes", forward_indexes) .AddFloatsArg("counts", counts) .AddIntsArg("input_time_range", input_time_range) .AddIntsArg("output_time_range", output_time_range) .Output("Output") .Finalize(net.NewOperatorDef()); // Run net.RunOp(); // Check auto expected = net.CreateTensor(output_shape, output_value); ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } } // namespace TEST_F(ExtractPoolingTest, SimpleCPU) { TestExtractPooling( {3, 20, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}, 9, 0, 0, {-2, 17}, {0, 3, 6, 9, 12, 15}, {0, 6, 2, 6}, {6, 4}, {0, 9}, {0, 17}, {3, 18, 3}, {29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5, 38.5, 39.5, 40.5}); } TEST_F(ExtractPoolingTest, SimpleCPUWithVariance) { TestExtractPooling( {3, 20, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}, 9, 1, 1, {-2, 17}, {0, 3, 6, 9, 12, 15}, {0, 6, 2, 6}, {6, 4}, {0, 9}, {0, 17}, {3, 18, 7}, {1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.79176, 29.5, 30.5, 31.5, 15.3704, 15.3704, 15.3704, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623, 1.386294, 38.5, 39.5, 40.5, 10.0623, 10.0623, 10.0623}); } } // namespace test } // namespace ops } // namespace mace