From 021e1e25591e85c7b176a5d07baacafce31b1188 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=8F=B6=E5=89=91=E6=AD=A6?=
Date: Fri, 18 Oct 2019 19:32:56 +0800
Subject: [PATCH] support tf.tile

fix type error
fix type warning
add TileOp to MaceTransposableDataFormatOps and reduce memcpy in TileOp
---
 mace/ops/registry/ops_registry.cc             |   2 +
 mace/ops/tile.cc                              | 127 +++++++++++++++++
 test/ccbenchmark/mace/ops/tile_benchmark.cc   |  76 +++++++++++
 test/ccunit/mace/ops/tile_test.cc             | 128 ++++++++++++++++++
 tools/python/transform/base_converter.py      |   4 +-
 .../python/transform/tensorflow_converter.py  |   6 +
 6 files changed, 342 insertions(+), 1 deletion(-)
 create mode 100644 mace/ops/tile.cc
 create mode 100644 test/ccbenchmark/mace/ops/tile_benchmark.cc
 create mode 100644 test/ccunit/mace/ops/tile_test.cc

diff --git a/mace/ops/registry/ops_registry.cc b/mace/ops/registry/ops_registry.cc
index 536fc296..b7b7f5f6 100644
--- a/mace/ops/registry/ops_registry.cc
+++ b/mace/ops/registry/ops_registry.cc
@@ -74,6 +74,7 @@ extern void RegisterStack(OpRegistryBase *op_registry);
 extern void RegisterStridedSlice(OpRegistryBase *op_registry);
 extern void RegisterSumGroup(OpRegistryBase *op_registry);
 extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
+extern void RegisterTile(OpRegistryBase *op_registry);
 extern void RegisterTranspose(OpRegistryBase *op_registry);
 extern void RegisterUnstack(OpRegistryBase *op_registry);
 extern void RegisterUnsqueeze(OpRegistryBase *op_registry);
@@ -148,6 +149,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterSqueeze(this);
   ops::RegisterSumGroup(this);
   ops::RegisterTargetRMSNorm(this);
+  ops::RegisterTile(this);
   ops::RegisterTranspose(this);
   ops::RegisterUnstack(this);
   ops::RegisterUnsqueeze(this);
diff --git a/mace/ops/tile.cc b/mace/ops/tile.cc
new file mode 100644
index 00000000..36d0bfe9
--- /dev/null
+++ b/mace/ops/tile.cc
@@ -0,0 +1,127 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstring>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/utils/memory.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, class T>
+class TileOp : public Operation {
+ public:
+  explicit TileOp(OpConstructContext *context)
+      : Operation(context),
+        has_data_format_(Operation::GetOptionalArg<int>("has_data_format",
+                                                        0)) {
+  }
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = this->Input(0);
+    const Tensor *multiples = this->Input(1);
+    const index_t input_dims = input->dim_size();
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard multiples_guard(multiples);
+    const T *input_data = input->data<T>();
+    const int32_t *multiples_data = multiples->data<int32_t>();
+
+    MACE_CHECK(multiples->dim_size() == 1, "multiples must be 1-dimensional. ",
+               multiples->dim_size());
+    MACE_CHECK(input_dims == multiples->size(),
+               "multiples length must be the same as the dim_size of input: ",
+               input_dims, " vs. ", multiples->size());
+
+    std::vector<int32_t> multiples_vec(multiples_data,
+                                       multiples_data + multiples->size());
+    if (has_data_format_ && input_dims == 4) {
+      // The framework has transposed the input to NCHW, while the multiples
+      // tensor still follows the original NHWC order, so reorder it to match.
+      int32_t h = multiples_vec[1];
+      int32_t w = multiples_vec[2];
+      int32_t c = multiples_vec[3];
+
+      multiples_vec[1] = c;
+      multiples_vec[2] = h;
+      multiples_vec[3] = w;
+    }
+
+    Tensor *output = this->Output(0);
+    std::vector<index_t> output_shape = {};
+    for (index_t i = 0; i < input_dims; ++i) {
+      output_shape.push_back(input->dim(i) * multiples_vec[i]);
+    }
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+    Tensor::MappingGuard output_guard(output);
+
+    T *output_data = output->mutable_data<T>();
+
+    const index_t output_byte_size = PadAlignSize(output->size() * sizeof(T));
+    ScratchBuffer *scratch = context->device()->scratch_buffer();
+    scratch->Rewind();
+    scratch->GrowSize(output_byte_size);
+
+    Tensor fake_input(scratch->Scratch(output_byte_size),
+                      DataTypeToEnum<T>::value);
+    T *fake_input_data = fake_input.mutable_data<T>();
+    std::memcpy(fake_input_data, input_data, input->size() * sizeof(T));
+
+    // Tile dimension by dimension, from the innermost to the outermost one.
+    // After each pass the partially tiled result is copied back into the
+    // scratch buffer and becomes the input of the next pass, so a single
+    // output-sized scratch buffer is enough.
+    index_t inner_dim = 1;
+    index_t outer_dim = input->size();
+    index_t acc_multiples = 1;
+    const index_t total_multiples =
+        std::accumulate(multiples_vec.begin(), multiples_vec.end(), 1,
+                        std::multiplies<int32_t>());
+    for (int64_t i = input_dims - 1; ; --i) {
+      inner_dim *= input->dim(i);
+      outer_dim /= input->dim(i);
+      for (int64_t o = 0; o < outer_dim; ++o) {
+        for (int64_t m = 0; m < multiples_vec[i]; ++m) {
+          std::memcpy(output_data + (o * multiples_vec[i] + m) * inner_dim,
+                      fake_input_data + o * inner_dim, inner_dim * sizeof(T));
+        }
+      }
+      acc_multiples *= multiples_vec[i];
+      if (acc_multiples == total_multiples) {
+        // All multiples have been applied; the output buffer is complete.
+        break;
+      }
+      std::memcpy(fake_input_data, output_data,
+                  input->size() * acc_multiples * sizeof(T));
+      inner_dim *= multiples_vec[i];
+    }
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  int has_data_format_;
+};
+
+void RegisterTile(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "Tile", TileOp, DeviceType::CPU, float);
+  MACE_REGISTER_OP_CONDITION(
+      op_registry, OpConditionBuilder("Tile").SetDevicePlacerFunc(
+          [](OpConditionContext *context) -> std::set<DeviceType> {
+            auto op = context->operator_def();
+            if (op->output_shape_size() != op->output_size()) {
+              return {DeviceType::CPU};
+            }
+            return {DeviceType::CPU};
+          }));
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/test/ccbenchmark/mace/ops/tile_benchmark.cc b/test/ccbenchmark/mace/ops/tile_benchmark.cc
new file mode 100644
index 00000000..468c243a
--- /dev/null
+++ b/test/ccbenchmark/mace/ops/tile_benchmark.cc
@@ -0,0 +1,76 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/benchmark_utils/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template <DeviceType D, typename T>
+void BMTileHelper(int iters, const std::vector<index_t> &input_shape) {
+  mace::testing::StopTiming();
+  // Construct graph
+  OpsTestNet net;
+  net.AddRandomInput<D, T>("Input", input_shape);
+  std::vector<int32_t> multiples = {};
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    multiples.push_back(2);
+  }
+  net.AddInputFromArray<D, int32_t>(
+      "Multiples", {static_cast<index_t>(multiples.size())}, multiples);
+
+  OpDefBuilder("Tile", "TileBM")
+      .Input("Input")
+      .Input("Multiples")
+      .Output("Output")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+    net.Sync();
+  }
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+    net.Sync();
+  }
+}
+}  // namespace
+
+#define MACE_BM_TILE_MACRO(N, H, W, C, TYPE, DEVICE)                          \
+  static void MACE_BM_TILE_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE(         \
+      int iters) {                                                            \
+    const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C;          \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                       \
+    BMTileHelper<DEVICE, TYPE>(iters, {N, H, W, C});                          \
+  }                                                                           \
+  MACE_BENCHMARK(MACE_BM_TILE_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
+
+#define MACE_BM_TILE(N, H, W, C) MACE_BM_TILE_MACRO(N, H, W, C, float, CPU);
+
+MACE_BM_TILE(1, 32, 32, 5);
+MACE_BM_TILE(1, 32, 32, 7);
+MACE_BM_TILE(1, 32, 32, 3);
+MACE_BM_TILE(1, 128, 128, 9);
+MACE_BM_TILE(1, 128, 128, 7);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/test/ccunit/mace/ops/tile_test.cc b/test/ccunit/mace/ops/tile_test.cc
new file mode 100644
index 00000000..db6a534c
--- /dev/null
+++ b/test/ccunit/mace/ops/tile_test.cc
@@ -0,0 +1,128 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class TileOpTest : public OpsTestBase {};
+
+namespace {
+
+void TestTile(const std::vector<index_t> &input_shape,
+              const std::vector<float> &input,
+              const std::vector<int32_t> &multiples,
+              const std::vector<index_t> &output_shape,
+              const std::vector<float> &output) {
+  OpsTestNet net;
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", input_shape, input);
+  net.AddInputFromArray<DeviceType::CPU, int32_t>(
+      "Multiples", {static_cast<index_t>(multiples.size())}, multiples);
+
+  OpDefBuilder("Tile", "TileOpTest")
+      .Input("Input")
+      .Input("Multiples")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp();
+
+  net.AddInputFromArray<DeviceType::CPU, float>("ExpectedOutput",
+                                                output_shape, output);
+  ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
+                          *net.GetOutput("Output"));
+}
+
+void TestTileWithDataFormat(const std::vector<index_t> &input_shape,
+                            const std::vector<float> &input,
+                            const std::vector<int32_t> &multiples,
+                            const std::vector<index_t> &output_shape,
+                            const std::vector<float> &output) {
+  OpsTestNet net;
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", input_shape, input);
+  net.AddInputFromArray<DeviceType::CPU, int32_t>(
+      "Multiples", {static_cast<index_t>(multiples.size())}, multiples);
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+
+  OpDefBuilder("Tile", "TileOpTest")
+      .Input("InputNCHW")
+      .Input("Multiples")
+      .Output("OutputNCHW")
+      .AddIntArg("has_data_format", 1)
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp();
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+  net.AddInputFromArray<DeviceType::CPU, float>("ExpectedOutput",
+                                                output_shape, output);
+  ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
+                          *net.GetOutput("Output"));
+}
+}  // namespace
+
+TEST_F(TileOpTest, SimpleTest) {
+  TestTile({2, 3}, {0, 1, 2, 3, 4, 5}, {2, 3}, {4, 9},
+           {0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5,
+            0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5});
+  TestTile({2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 1, 2},
+           {2, 2, 6}, {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5,
+                       6, 7, 8, 6, 7, 8, 9, 10, 11, 9, 10, 11});
+  TestTile({2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {2, 1, 2},
+           {4, 2, 6}, {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
+                       7, 8, 9, 10, 11, 9, 10, 11, 0, 1, 2, 0, 1, 2, 3, 4,
+                       5, 3, 4, 5, 6, 7, 8, 6, 7, 8, 9, 10, 11, 9, 10, 11});
+  TestTile({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                          12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+           {1, 1, 1, 2}, {2, 2, 2, 6},
+           {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
+            7, 8, 9, 10, 11, 9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16,
+            17, 15, 16, 17, 18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23});
+  TestTile({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                          12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+           {1, 2, 2, 1}, {2, 4, 4, 3},
+           {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+            10, 11, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0, 1,
+            2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11,
+            12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+            22, 23, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 12, 13,
+            14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23});
+}
+
+TEST_F(TileOpTest, TestTileWithDataFormat) {
+  TestTileWithDataFormat(
+      {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                     12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+      {1, 1, 1, 2}, {2, 2, 2, 6},
+      {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 6,
+       7, 8, 9, 10, 11, 9, 10, 11, 12, 13, 14, 12, 13, 14, 15, 16,
+       17, 15, 16, 17, 18, 19, 20, 18, 19, 20, 21, 22, 23, 21, 22, 23});
+  TestTileWithDataFormat(
+      {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                     12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+      {1, 2, 2, 1}, {2, 4, 4, 3},
+      {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+       10, 11, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0, 1,
+       2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11,
+       12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+       22, 23, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 12, 13,
+       14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 18, 19, 20, 21, 22, 23});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py
index b0a2e4ab..fe93d048 100644
--- a/tools/python/transform/base_converter.py
+++ b/tools/python/transform/base_converter.py
@@ -153,6 +153,7 @@ MaceSupportedOps = [
     'TargetRMSNorm',
     'Transpose',
     'Cumsum',
+    'Tile',
 ]
 
 MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
@@ -185,7 +186,8 @@ MaceTransposableDataFormatOps = [MaceOp.Activation,
                                  MaceOp.Softmax,
                                  MaceOp.Split,
                                  MaceOp.Squeeze,
-                                 MaceOp.SqrDiffMean]
+                                 MaceOp.SqrDiffMean,
+                                 MaceOp.Tile]
 
 
 class MaceKeyword(object):
diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py
index 1ad82f21..69bf0538 100644
--- a/tools/python/transform/tensorflow_converter.py
+++ b/tools/python/transform/tensorflow_converter.py
@@ -126,6 +126,7 @@ TFSupportedOps = [
     'MirrorPad',
     'Cumsum',
     'OneHot',
+    'Tile',
 ]
 
 TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
@@ -276,6 +277,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Cumsum.name: self.convert_cumsum,
             TFOpType.OneHot.name: self.convert_one_hot,
             TFOpType.Sum.name: self.convert_reduce,
+            TFOpType.Tile.name: self.convert_tile,
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
@@ -1050,6 +1052,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
             del op.input[0]
             self._skip_tensor.add(tf_op.inputs[0].name)
 
+    def convert_tile(self, tf_op):
+        op = self.convert_general_op(tf_op)
+        op.type = MaceOp.Tile.name
+
     def convert_fake_quantize(self, tf_op):
         op = self.convert_general_op(tf_op)
         min_arg = op.arg.add()
-- 
GitLab
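
For reference, a TensorFlow graph as small as the one below is enough to
exercise the new convert_tile path end to end. This is only an illustrative
sketch: the tensor names, shapes and multiples are arbitrary, and it assumes
the TensorFlow 1.x API that the converter tooling targets.

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default():
        data = tf.placeholder(tf.float32, shape=[1, 32, 32, 3], name='input')
        # tf.tile repeats the tensor multiples[i] times along dimension i;
        # the converter maps this node onto the new MACE Tile op.
        tiled = tf.tile(data, multiples=[1, 2, 2, 1], name='output')
        with tf.Session() as sess:
            out = sess.run(tiled, feed_dict={
                data: np.zeros([1, 32, 32, 3], dtype=np.float32)})
            print(out.shape)  # (1, 64, 64, 3)

Freezing such a graph and running it through the MACE converter should now
produce a Tile op that executes on CPU via mace/ops/tile.cc.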