// Copyright 2018 Xiaomi, Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "mace/ops/ops_test_util.h" namespace mace { namespace ops { namespace test { class PadTest : public OpsTestBase {}; namespace { template void Simple() { // Construct graph OpsTestNet net; // Add input data net.AddRepeatedInput("Input", {1, 2, 3, 1}, 2); if (D == DeviceType::GPU) { OpDefBuilder("Pad", "PadTest") .Input("Input") .Output("Output") .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddFloatArg("constant_value", 1.0) .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); } else { net.TransformDataFormat("Input", NHWC, "TInput", NCHW); OpDefBuilder("Pad", "PadTest") .Input("TInput") .Output("TOutput") .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddFloatArg("constant_value", 1.0) .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run net.RunOp(); net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); } auto output = net.GetTensor("Output"); auto expected = net.CreateTensor( {1, 5, 6, 1}, { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2, 2, 2, 1.0, 1.0, 1.0, 2, 2, 2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, }); ExpectTensorNear(*expected, *output, 1e-5); } } // namespace TEST_F(PadTest, SimpleCPU) { Simple(); } TEST_F(PadTest, SimpleGPU) { Simple(); } TEST_F(PadTest, ComplexCPU) { // Construct graph OpsTestNet net; // Add input data net.AddRepeatedInput("Input", {1, 1, 1, 2}, 2); net.TransformDataFormat("Input", NHWC, "TInput", NCHW); OpDefBuilder("Pad", "PadTest") .Input("TInput") .Output("TOutput") .AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1}) .AddFloatArg("constant_value", 1.0) .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run net.RunOp(); net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); auto output = net.GetTensor("Output"); auto expected = net.CreateTensor( {1, 3, 3, 4}, { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, }); ExpectTensorNear(*expected, *output, 1e-5); } namespace { template void Complex(const std::vector &input_shape, const std::vector &paddings) { // Construct graph OpsTestNet net; // Add input data net.AddRandomInput("Input", input_shape); net.TransformDataFormat("Input", NHWC, "TInput", NCHW); OpDefBuilder("Pad", "PadTest") .Input("TInput") .Output("TOutput") .AddIntsArg("paddings", paddings) .AddFloatArg("constant_value", 1.0) .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run net.RunOp(); net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); auto expected = net.CreateTensor(); expected->Copy(*net.GetOutput("Output")); OpDefBuilder("Pad", "PadTest") .Input("Input") .Output("Output") .AddIntsArg("paddings", paddings) .AddFloatArg("constant_value", 1.0) .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run net.RunOp(DeviceType::GPU); auto output = net.GetTensor("Output"); if (DataTypeToEnum::value == DT_HALF) { ExpectTensorNear(*expected, *output, 1e-2, 1e-2); } else { ExpectTensorNear(*expected, *output, 1e-5); } } } // namespace TEST_F(PadTest, ComplexFloat) { Complex({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0}); Complex({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0}); Complex({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0}); } TEST_F(PadTest, ComplexHalf) { Complex({1, 32, 32, 4}, {0, 0, 2, 2, 1, 1, 0, 0}); Complex({1, 31, 37, 16}, {0, 0, 2, 0, 1, 0, 0, 0}); Complex({1, 128, 128, 32}, {0, 0, 0, 1, 0, 2, 0, 0}); } } // namespace test } // namespace ops } // namespace mace