diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 7177a42281b456ad074f413752c7ff86a7900f3f..274f824513a4d891b48b663f30986041853ca82a 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -237,11 +237,6 @@ MaceStatus MaceEngine::Impl::Run( << "' is not belong to model's outputs: " << MakeString(MapKeys(output_info_map_)); } - if (device_type_ == DeviceType::GPU) { - MACE_CHECK(output.second.shape().size() == 4, - "The outputs' shape must be 4-dimension with NHWC format," - " please use 1 to fill missing dimensions"); - } Tensor *output_tensor = ws_->GetTensor(MakeString("mace_output_node_", output.first)); output_tensors.push_back(output_tensor); diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 4bc3cfade4fcd129b68687b0ac03b08961a11d6d..9920e5434265183f1f45a7eec924c5e92f61a482 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -90,6 +90,7 @@ extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); extern void Register_FullyConnected(OperatorRegistry *op_registry); extern void Register_Gather(OperatorRegistry *op_registry); +extern void Register_Identity(OperatorRegistry *op_registry); extern void Register_LocalResponseNorm(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_Pad(OperatorRegistry *op_registry); @@ -107,6 +108,7 @@ extern void Register_Stack(OperatorRegistry *op_registry); extern void Register_StridedSlice(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_SpaceToDepth(OperatorRegistry *op_registry); +extern void Register_Squeeze(OperatorRegistry *op_registry); extern void Register_Transpose(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry); @@ -135,6 +137,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); ops::Register_Gather(this); + ops::Register_Identity(this); ops::Register_LocalResponseNorm(this); ops::Register_MatMul(this); ops::Register_Pad(this); @@ -152,6 +155,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_StridedSlice(this); ops::Register_SpaceToBatchND(this); ops::Register_SpaceToDepth(this); + ops::Register_Squeeze(this); ops::Register_Transpose(this); ops::Register_WinogradInverseTransform(this); ops::Register_WinogradTransform(this); diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 322670de55b0a8edb71b5d8d37e4aed739c13923..f0efde525367447cfe19d9f15e75b221a73c8d9d 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include "mace/core/arg_helper.h" @@ -22,6 +23,15 @@ namespace mace { +namespace { +bool ShouldPreallocateMemoryForOp(const OperatorDef &op) { + static const std::unordered_set reuse_buffer_ops { + "Reshape", "Identity", "Squeeze" + }; + return reuse_buffer_ops.find(op.type()) == reuse_buffer_ops.end(); +} +} // namespace + Workspace::Workspace() : host_scratch_buffer_(new ScratchBuffer( GetDeviceAllocator(DeviceType::CPU))) {} @@ -177,7 +187,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ProtoArgHelper::GetOptionalArg( op, "device", static_cast(device_type)); if (op_device == device_type && !op.mem_id().empty() - && op.type() != "Reshape") { + && ShouldPreallocateMemoryForOp(op)) { auto mem_ids = op.mem_id(); int count = mem_ids.size(); for (int i = 0; i < count; ++i) { diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 394c48ffe58e98fce5abbda6be9283772458c9d4..fab8d3260b37bdc29634e82c0a111b5afc509546 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -136,7 +136,19 @@ MaceStatus BufferToImageFunctor::operator()( b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } else if (type == ARGUMENT) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); - } else { + } else if (type == IN_OUT_CHANNEL) { + if (buffer->dim_size() == 4) { // NHWC + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + } else if (buffer->dim_size() == 2) { // NC + b2f_kernel.setArg(idx++, static_cast(1)); + b2f_kernel.setArg(idx++, static_cast(1)); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + } else { + MACE_NOT_IMPLEMENTED; + } + } else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); if (buffer->dim_size() < 4) { @@ -144,6 +156,10 @@ MaceStatus BufferToImageFunctor::operator()( } else { b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } + } else { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } b2f_kernel.setArg(idx++, *(image->opencl_image())); diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index d157db82bbb8265a205c9a0f814b2dea1c00dd2d..16a922bebf654bd37fddd1d93260d2e28dcc2495 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -28,10 +28,15 @@ namespace { // [(C + 3) / 4 * W, N * H] void CalInOutputImageShape(const std::vector &shape, /* NHWC */ std::vector *image_shape) { - MACE_CHECK(shape.size() == 4); + MACE_CHECK(shape.size() == 4 || shape.size() == 2); image_shape->resize(2); - (*image_shape)[0] = RoundUpDiv4(shape[3]) * shape[2]; - (*image_shape)[1] = shape[0] * shape[1]; + if (shape.size() == 4) { + (*image_shape)[0] = RoundUpDiv4(shape[3]) * shape[2]; + (*image_shape)[1] = shape[0] * shape[1]; + } else if (shape.size() == 2) { + (*image_shape)[0] = RoundUpDiv4(shape[1]); + (*image_shape)[1] = shape[0]; + } } // [Ic, H * W * (Oc + 3) / 4] diff --git a/mace/kernels/opencl/image_to_buffer.cc b/mace/kernels/opencl/image_to_buffer.cc index f293189b922142a04840052df767dc31a35efd62..dcaa1c6465801006d91b027791ac2bc2dfdc4ab0 100644 --- a/mace/kernels/opencl/image_to_buffer.cc +++ b/mace/kernels/opencl/image_to_buffer.cc @@ -123,7 +123,19 @@ MaceStatus ImageToBufferFunctor::operator()( b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); - } else { + } else if (type == IN_OUT_CHANNEL) { + if (buffer->dim_size() == 4) { // NHWC + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + } else if (buffer->dim_size() == 2) { // NC + b2f_kernel.setArg(idx++, static_cast(1)); + b2f_kernel.setArg(idx++, static_cast(1)); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + } else { + MACE_NOT_IMPLEMENTED; + } + } else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); if (buffer->dim_size() < 4) { @@ -131,6 +143,10 @@ MaceStatus ImageToBufferFunctor::operator()( } else { b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } + } else { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } b2f_kernel.setArg(idx++, *(image->opencl_image())); diff --git a/mace/kernels/opencl/softmax.cc b/mace/kernels/opencl/softmax.cc index cfaee93a3cd8ebe441c736165b2806d63782571e..76dc2c7f810aadf9042c5a5c6ce07875fdf799f3 100644 --- a/mace/kernels/opencl/softmax.cc +++ b/mace/kernels/opencl/softmax.cc @@ -45,10 +45,26 @@ template MaceStatus SoftmaxFunctor::operator()(const Tensor *logits, Tensor *output, StatsFuture *future) { - const index_t batch = logits->dim(0); - const index_t height = logits->dim(1); - const index_t width = logits->dim(2); - const index_t channels = logits->dim(3); + index_t batch = 0; + index_t height = 0; + index_t width = 0; + index_t channels = 0; + + if (logits->dim_size() == 2) { + batch = logits->dim(0); + height = 1; + width = 1; + channels = logits->dim(1); + + } else if (logits->dim_size() == 4) { + batch = logits->dim(0); + height = logits->dim(1); + width = logits->dim(2); + channels = logits->dim(3); + } else { + MACE_NOT_IMPLEMENTED; + } + const index_t channel_blocks = RoundUpDiv4(channels); const int remain_channels = channel_blocks * 4 - channels; @@ -103,8 +119,7 @@ MaceStatus SoftmaxFunctor::operator()(const Tensor *logits, std::vector lws = LocalWS(gws, kwg_size_); std::string tuning_key = - Concat("softmax_opencl_kernel", output->dim(0), output->dim(1), - output->dim(2), output->dim(3)); + Concat("softmax_opencl_kernel", batch, height, width, channels); TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future); if (runtime->IsOutOfRangeCheckEnabled()) { diff --git a/mace/ops/identity.cc b/mace/ops/identity.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed89561a231d08438f35ae1ec53ecf45c0c806b5 --- /dev/null +++ b/mace/ops/identity.cc @@ -0,0 +1,43 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/identity.h" + +namespace mace { +namespace ops { + +void Register_Identity(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + IdentityOp); + +#ifdef MACE_ENABLE_OPENCL + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + IdentityOp); + + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + IdentityOp); +#endif +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/identity.h b/mace/ops/identity.h new file mode 100644 index 0000000000000000000000000000000000000000..d2aa7446e52a50b4134c4f7455c6a93bfe71a8dd --- /dev/null +++ b/mace/ops/identity.h @@ -0,0 +1,46 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_IDENTITY_H_ +#define MACE_OPS_IDENTITY_H_ + +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class IdentityOp : public Operator { + public: + IdentityOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + output->ReuseTensorBuffer(*input); + return MACE_SUCCESS; + } + + private: + MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_IDENTITY_H_ diff --git a/mace/ops/identity_test.cc b/mace/ops/identity_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..26d835ce4d2260eb3f5aa95d57ab79f86523e357 --- /dev/null +++ b/mace/ops/identity_test.cc @@ -0,0 +1,62 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class IdentityOpTest : public OpsTestBase {}; + +namespace { +void TestIdentity(const std::vector &shape) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Identity", "IdentityTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", shape); + + // Run + net.RunOp(); + + auto input = net.GetTensor("Input"); + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(shape)); + + const float *input_ptr = input->data(); + const float *output_ptr = output->data(); + const int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_ptr[i], output_ptr[i]); + } +} +} // namespace + +TEST_F(IdentityOpTest, TestIdentity) { + TestIdentity({1, 2, 3, 4}); + TestIdentity({1}); + TestIdentity({}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 4390c520dd2daf7d7d67fa2e393ace58ec392b61..ff0befc25b0f846b973df5c524fe5da30e684e37 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -23,6 +23,20 @@ void Register_Reshape(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), ReshapeOp); + +#ifdef MACE_ENABLE_OPENCL + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + ReshapeOp); + + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + ReshapeOp); +#endif } } // namespace ops diff --git a/mace/ops/reshape.h b/mace/ops/reshape.h index fe1df988b5dcf911f4bceb5fa122ea9487ec6712..90a443144bb87d32f8d99d722ef75554195772a8 100644 --- a/mace/ops/reshape.h +++ b/mace/ops/reshape.h @@ -27,26 +27,29 @@ template class ReshapeOp : public Operator { public: ReshapeOp(const OperatorDef &op_def, Workspace *ws) - : Operator(op_def, ws), - shape_(OperatorBase::GetRepeatedArgs("shape")) {} + : Operator(op_def, ws) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); - const index_t num_dims = shape_.size(); + const Tensor *shape = this->Input(SHAPE); + const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0); + Tensor::MappingGuard shape_guard(shape); + const int32_t *shape_data = shape->data(); + int unknown_idx = -1; index_t product = 1; std::vector out_shape; for (int i = 0; i < num_dims; ++i) { - if (shape_[i] == -1) { + if (shape_data[i] == -1) { MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; unknown_idx = i; out_shape.push_back(1); } else { - MACE_CHECK(shape_[i] >= 0) << "Shape must be non-negative: " - << shape_[i]; - out_shape.push_back(shape_[i]); - product *= shape_[i]; + MACE_CHECK(shape_data[i] >= 0) << "Shape must be non-negative: " + << shape_data[i]; + out_shape.push_back(shape_data[i]); + product *= shape_data[i]; } } @@ -65,11 +68,10 @@ class ReshapeOp : public Operator { } private: - std::vector shape_; kernels::ReshapeFunctor functor_; private: - MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_INPUT_TAGS(INPUT, SHAPE); MACE_OP_OUTPUT_TAGS(OUTPUT); }; diff --git a/mace/ops/reshape_test.cc b/mace/ops/reshape_test.cc index 2b24277718a2c1c8d11fb77dd05fbf0886956c46..4e48d384162e3401e083384a0d1176a7227114a3 100644 --- a/mace/ops/reshape_test.cc +++ b/mace/ops/reshape_test.cc @@ -30,12 +30,15 @@ void TestReshape(const std::vector &org_shape, OpsTestNet net; OpDefBuilder("Reshape", "ReshapeTest") .Input("Input") + .Input("Shape") .Output("Output") - .AddIntsArg("shape", output_shape) .Finalize(net.NewOperatorDef()); // Add input data net.AddRandomInput("Input", org_shape); + net.AddInputFromArray("Shape", + {output_shape.size()}, + output_shape); // Run net.RunOp(); diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc index 5468ca244412df914854e8c9f47ef4818d4ad7da..62f7f32f8a1ff4667dbe3660b0b516e2897d7fa5 100644 --- a/mace/ops/softmax_test.cc +++ b/mace/ops/softmax_test.cc @@ -93,17 +93,25 @@ void Complex(const std::vector &logits_shape) { // Add input data net.AddRandomInput("Input", logits_shape); - net.TransformDataFormat("Input", NHWC, "InputNCHW", NCHW); - - OpDefBuilder("Softmax", "SoftmaxTest") - .Input("InputNCHW") - .Output("OutputNCHW") - .Finalize(net.NewOperatorDef()); + if (logits_shape.size() == 4) { + net.TransformDataFormat("Input", NHWC, "InputNCHW", NCHW); + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("InputNCHW") + .Output("OutputNCHW") + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } // Run on cpu net.RunOp(); - net.TransformDataFormat("OutputNCHW", NCHW, "Output", NHWC); + if (logits_shape.size() == 4) { + net.TransformDataFormat("OutputNCHW", NCHW, "Output", NHWC); + } Tensor expected; expected.Copy(*net.GetOutput("Output")); @@ -142,6 +150,11 @@ TEST_F(SoftmaxOpTest, OPENCLUnAligned) { Complex({5, 211, 107, 1}); } +TEST_F(SoftmaxOpTest, OPENCLAlignedRank2) { + Complex({1, 1001}); + Complex({3, 1001}); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/squeeze.cc b/mace/ops/squeeze.cc new file mode 100644 index 0000000000000000000000000000000000000000..e917936fc949d9ccf99ba611753e97fa7a503248 --- /dev/null +++ b/mace/ops/squeeze.cc @@ -0,0 +1,43 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/squeeze.h" + +namespace mace { +namespace ops { + +void Register_Squeeze(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SqueezeOp); + +#ifdef MACE_ENABLE_OPENCL + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + SqueezeOp); + + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + SqueezeOp); +#endif +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/squeeze.h b/mace/ops/squeeze.h new file mode 100644 index 0000000000000000000000000000000000000000..b736955f24d76936b4c15451860353337487a444 --- /dev/null +++ b/mace/ops/squeeze.h @@ -0,0 +1,64 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_SQUEEZE_H_ +#define MACE_OPS_SQUEEZE_H_ + +#include +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class SqueezeOp : public Operator { + public: + SqueezeOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + axis_(OperatorBase::GetRepeatedArgs("axis", {})) {} + + MaceStatus Run(StatsFuture *future) override { + MACE_UNUSED(future); + + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + + std::vector output_shape; + std::unordered_set axis_set(axis_.begin(), axis_.end()); + for (int i = 0; i < input->dim_size(); ++i) { + if (input->dim(i) > 1 + || (!axis_set.empty() && axis_set.find(i) == axis_set.end())) { + output_shape.push_back(input->dim(i)); + } + } + output->ReuseTensorBuffer(*input); + output->Reshape(output_shape); + + return MACE_SUCCESS; + } + + private: + std::vector axis_; + + private: + MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_SQUEEZE_H_ diff --git a/mace/ops/squeeze_test.cc b/mace/ops/squeeze_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..35f224c9a901dab81de1469c9218a0bb3b7debd8 --- /dev/null +++ b/mace/ops/squeeze_test.cc @@ -0,0 +1,66 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SqueezeOpTest : public OpsTestBase {}; + +namespace { +void TestSqueeze(const std::vector &org_shape, + const std::vector &axis, + const std::vector &res_shape) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Squeeze", "SqueezeTest") + .Input("Input") + .AddIntsArg("axis", axis) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", org_shape); + + // Run + net.RunOp(); + + auto input = net.GetTensor("Input"); + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(res_shape)); + + const float *input_ptr = input->data(); + const float *output_ptr = output->data(); + const int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_ptr[i], output_ptr[i]); + } +} +} // namespace + +TEST_F(SqueezeOpTest, TestSqueeze) { + TestSqueeze({1, 2, 1, 4}, {}, {2, 4}); + TestSqueeze({1, 2, 1, 4}, {1}, {1, 2, 1, 4}); + TestSqueeze({1, 2, 1, 4}, {2}, {1, 2, 4}); + TestSqueeze({1}, {}, {}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 20ce14b48cd40a7352807dd44158f2e41d7fdb32..918c26bae164ac200d882a00b19f2ffafc1d8492 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -84,6 +84,8 @@ MaceSupportedOps = [ 'Eltwise', 'FoldedBatchNorm', 'FullyConnected', + 'Gather', + 'Identity', 'LocalResponseNorm', 'MatMul', 'Pad', @@ -95,6 +97,10 @@ MaceSupportedOps = [ 'Reshape', 'ResizeBilinear', 'Slice', + 'Shape', + 'Squeeze', + 'Stack', + 'StridedSlice', 'Softmax', 'SpaceToBatchND', 'SpaceToDepth', @@ -144,7 +150,6 @@ class MaceKeyword(object): class TransformerRule(Enum): - REMOVE_USELESS_RESHAPE_OP = 0 REMOVE_IDENTITY_OP = 1 TRANSFORM_GLOBAL_POOLING = 2 FOLD_RESHAPE = 3 @@ -212,7 +217,6 @@ class ConverterOption(object): self._winograd_enabled = False self._transformer_option = [ TransformerRule.ADD_IN_OUT_TENSOR_INFO, - TransformerRule.REMOVE_USELESS_RESHAPE_OP, TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.FOLD_RESHAPE, diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index b5f644e4ee6ae3410fbf1e998a8426ba07e1e512..8068cf5a9b1705fc55ebc28f8b51ef568d727486 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -41,6 +41,8 @@ tf_kernel_str = 'ksize' tf_epsilon_str = 'epsilon' tf_align_corners = 'align_corners' tf_block_size = 'block_size' +tf_squeeze_dims = 'squeeze_dims' +tf_axis = 'axis' TFSupportedOps = [ 'Conv2D', @@ -149,7 +151,8 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.MatMul.name: self.convert_matmul, TFOpType.Identity.name: self.convert_identity, TFOpType.Reshape.name: self.convert_reshape, - TFOpType.Shape.name: self.convert_nop, + TFOpType.Shape.name: self.convert_shape, + TFOpType.Squeeze.name: self.convert_squeeze, TFOpType.Transpose.name: self.convert_transpose, TFOpType.Softmax.name: self.convert_softmax, TFOpType.ResizeBilinear.name: self.convert_resize_bilinear, @@ -257,6 +260,16 @@ class TensorflowConverter(base_converter.ConverterInterface): tensor.data_type = data_type tensor.float_data.extend(value.flat) + # this function tries to infer tensor shape, but some dimension shape + # may be undefined due to variance of input length + @staticmethod + def infer_tensor_shape(tensor): + shape = tensor.shape.as_list() + + def normalize_func(dim): + return dim if dim else - 1 + return [normalize_func(dim) for dim in shape] + def convert_nop(self, tf_op): pass @@ -268,7 +281,7 @@ class TensorflowConverter(base_converter.ConverterInterface): op.output.extend([tf_output.name for tf_output in tf_op.outputs]) for tf_output in tf_op.outputs: output_shape = op.output_shape.add() - output_shape.dims.extend(tf_output.shape.as_list()) + output_shape.dims.extend(self.infer_tensor_shape(tf_output)) ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) @@ -481,24 +494,29 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.MatMul.name + def convert_shape(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Shape.name + op.output_type.extend([mace_pb2.DT_INT32]) + def convert_reshape(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.Reshape.name - del op.input[1:] - shape_arg = op.arg.add() - shape_arg.name = MaceKeyword.mace_shape_str - shape_value = [] - if tf_op.inputs[1].op.type == TFOpType.Const.name: - shape_value = list(tf_op.inputs[1].eval().astype(np.int32)) - for i in xrange(len(shape_value)): - if shape_value[i] == -1: - shape_value[i] = 1 - self._skip_tensor.add(tf_op.inputs[-1].name) - elif tf_op.inputs[1].op.type == TFOpType.Shape.name: - shape_value = list(tf_op.inputs[1].op.inputs[0].shape.as_list()) - - shape_arg.ints.extend(shape_value) + def convert_squeeze(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Squeeze.name + + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + try: + axis_value = tf_op.get_attr('squeeze_dims') + except ValueError: + try: + axis_value = tf_op.get_attr('axis') + except ValueError: + axis_value = [] + axis_arg.ints.extend(axis_value) def convert_transpose(self, tf_op): perm = tf_op.inputs[1].eval().astype(np.int32) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 8f4417d2ecf095891234e7e5b2a5541064d5cc01..8c8987bf9d758944dcf8f2656672a969dd9558ac 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -56,7 +56,6 @@ class Transformer(base_converter.ConverterInterface): # DO NOT reorder the following transformers' order self._registered_transformers_order = [ TransformerRule.ADD_IN_OUT_TENSOR_INFO, - TransformerRule.REMOVE_USELESS_RESHAPE_OP, TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.FOLD_RESHAPE, @@ -81,8 +80,6 @@ class Transformer(base_converter.ConverterInterface): self._registered_transformers = { TransformerRule.ADD_IN_OUT_TENSOR_INFO: self.add_in_out_tensor_info, - TransformerRule.REMOVE_USELESS_RESHAPE_OP: - self.remove_useless_reshape_op, TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, TransformerRule.TRANSFORM_GLOBAL_POOLING: self.transform_global_pooling, @@ -289,19 +286,6 @@ class Transformer(base_converter.ConverterInterface): return False - def remove_useless_reshape_op(self): - net = self._model - for op in net.op: - if op.type == MaceOp.Reshape.name: - shape = list(ConverterUtil.get_arg( - op, MaceKeyword.mace_shape_str).ints) - if shape == self.get_tensor_shape(op.input[0]): - print("Remove useless reshape: %s(%s)" - % (op.name, op.type)) - op.type = 'Identity' - - return False - def remove_identity_op(self): net = self._model for op in net.op: @@ -791,6 +775,26 @@ class Transformer(base_converter.ConverterInterface): "channel dimension") arg.i = 3 + elif op.type == MaceOp.Squeeze.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_axis_str: + if ConverterUtil.data_format( + op) == DataFormat.NHWC \ + and self._target_data_format == DataFormat.NCHW: # noqa + print("Transpose squeeze args: %s(%s)" + % (op.name, op.type)) + mace_check(list(arg.ints) == [1, 2], + 'only support squeeze at at [1, 2]') + arg.ints[:] = [2, 3] + elif ConverterUtil.data_format( + op) == DataFormat.NCHW \ + and self._target_data_format == DataFormat.NHWC: # noqa + print("Transpose squeeze args: %s(%s)" + % (op.name, op.type)) + mace_check(list(arg.ints) == [2, 3], + 'only support squeeze at at [2, 3]') + arg.ints[:] = [1, 2] + # transpose op output shape data_format = ConverterUtil.data_format(op) if data_format is not None \ @@ -818,16 +822,19 @@ class Transformer(base_converter.ConverterInterface): + '_' + input_node.name op = net.op.add() op.name = self.normalize_op_name(input_node.name) - op.type = MaceOp.Transpose.name op.input.extend([new_input_name]) op.output.extend([input_node.name]) output_shape = op.output_shape.add() output_shape.dims.extend(input_node.shape) - self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) + if len(output_shape.dims) == 4: + op.type = MaceOp.Transpose.name + self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) - dims_arg = op.arg.add() - dims_arg.name = MaceKeyword.mace_dims_str - dims_arg.ints.extend([0, 3, 1, 2]) + dims_arg = op.arg.add() + dims_arg.name = MaceKeyword.mace_dims_str + dims_arg.ints.extend([0, 3, 1, 2]) + else: + op.type = MaceOp.Identity.name ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) @@ -836,19 +843,22 @@ class Transformer(base_converter.ConverterInterface): + '_' + output_node.name op = self._model.op.add() op.name = self.normalize_op_name(output_name) - op.type = MaceOp.Transpose.name op.input.extend([output_node.name]) op.output.extend([output_name]) output_shape = op.output_shape.add() output_shape.dims.extend( self._producer[output_node.name].output_shape[0].dims) - self.transpose_shape(output_shape.dims, [0, 2, 3, 1]) + if len(output_shape.dims) == 4: + op.type = MaceOp.Transpose.name + self.transpose_shape(output_shape.dims, [0, 2, 3, 1]) - dims_arg = op.arg.add() - dims_arg.name = MaceKeyword.mace_dims_str - dims_arg.ints.extend([0, 2, 3, 1]) + dims_arg = op.arg.add() + dims_arg.name = MaceKeyword.mace_dims_str + dims_arg.ints.extend([0, 2, 3, 1]) - ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) + ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) + else: + op.type = MaceOp.Identity.name return False @@ -1003,36 +1013,51 @@ class Transformer(base_converter.ConverterInterface): return False def fold_reshape(self): - changed = False net = self._model for op in net.op: - if op.type == MaceOp.Softmax.name or op.type == MaceOp.MatMul.name: - print("Fold reshape: %s(%s)" % (op.name, op.type)) - if self.consumer_count(op.output[0]) == 1: + if op.type == MaceOp.Softmax.name: + # see if possible to fold + # Reshape(xd->2d) + Softmax(2d) + Reshape(xd) to Softmax(xd) + should_fold = False + if op.input[0] in self._producer \ + and self._producer[op.input[0]].type \ + == MaceOp.Reshape.name \ + and len(op.output_shape[0].dims) == 2 \ + and self.consumer_count(op.output[0]) == 1: + producer = self._producer[op.input[0]] consumer = self._consumers[op.output[0]][0] - if consumer.type == MaceOp.Reshape.name: - shape = ConverterUtil.get_arg(consumer, - MaceKeyword.mace_shape_str).ints # noqa - del op.output_shape[0].dims[:] - op.output_shape[0].dims.extend(shape) - self.safe_remove_node(consumer, op) - changed = True + if (consumer.type == MaceOp.Reshape.name + and op.output_shape[0].dims[-1] + == consumer.output_shape[0].dims[-1] + and op.output_shape[0].dims[-1] != -1 + and self.get_tensor_shape(producer.input[0]) + == consumer.output_shape[0].dims): + should_fold = True + + if should_fold: + print( + "Fold reshape and softmax: %s(%s)" + % (op.name, op.type)) producer = self._producer[op.input[0]] - if producer.type == MaceOp.Reshape.name: - self.safe_remove_node(producer, - self._producer[ - producer.input[0]]) - changed = True - - if len(op.output_shape[0].dims) < 4: - shape = ([1, 1, 1, 1] + list(op.output_shape[0].dims))[-4:] - op.output_shape[0].dims[:] = shape[:] - changed = True - - if changed: - return True + consumer = self._consumers[op.output[0]][0] + op.output_shape[0].dims[:] = self.get_tensor_shape( + producer.input[0]) + + # if there is a shape op, remove it too + if (consumer.input[1] in self._producer + and self._producer[consumer.input[1]].type + == 'Shape'): + self.safe_remove_node( + self._producer[consumer.input[1]], None) + # remove consumer reshape + self.safe_remove_node(consumer, op) + # remove producer reshape + self.safe_remove_node(producer, + self._producer.get(producer.input[0], + None)) + return True return False def transform_matmul_to_fc(self): diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index c05b94be7780c628495dbb85523e657dd55aeddc..44f11c5f3dc96da0857c813d005c416650ddd56c 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -83,6 +83,11 @@ class MemoryOptimizer(object): optimized_mem_size += self.mem_size(self.mem_block[mem]) return optimized_mem_size + @staticmethod + def is_memory_reuse_op(op): + return op.type == 'Reshape' or op.type == 'Identity' \ + or op.type == 'Squeeze' + def optimize(self): for op in self.net_def.op: if not self.op_need_optimize_memory(op): @@ -96,51 +101,59 @@ class MemoryOptimizer(object): 'the number of output.') return for i in range(len(op.output)): - op_mem_block = self.get_op_mem_block(op.type, - op.output_shape[i].dims) - mem_id = -1 - if len(self.idle_mem) > 0: - best_mem_add_size = sys.maxint - best_mem_waste_size = sys.maxint - for mid in self.idle_mem: - old_mem_block = self.mem_block[mid] - new_mem_block = self.resize_mem_block( - old_mem_block, op_mem_block) - add_mem_size = self.sub_mem_block(new_mem_block, - old_mem_block) - waste_mem_size = self.sub_mem_block(new_mem_block, - op_mem_block) - - # minimize add_mem_size; if best_mem_add_size is 0, - # then minimize waste_mem_size - if (best_mem_add_size > 0 and - add_mem_size < best_mem_add_size) \ - or (best_mem_add_size == 0 and - waste_mem_size < best_mem_waste_size): - best_mem_id = mid - best_mem_add_size = add_mem_size - best_mem_waste_size = waste_mem_size - best_mem_block = new_mem_block - - # if add mem size < op mem size, then reuse it - if best_mem_add_size <= self.mem_size(op_mem_block): - self.mem_block[best_mem_id] = best_mem_block - mem_id = best_mem_id - self.idle_mem.remove(mem_id) - - if mem_id == -1: - mem_id = self.mem_id_base() + self.total_mem_count - self.total_mem_count += 1 - self.mem_block[mem_id] = op_mem_block - - op.mem_id.extend([mem_id]) - self.op_mem[op.output[i]] = mem_id + if self.is_memory_reuse_op(op): + # make these ops reuse memory of input tensor + mem_id = self.op_mem.get(op.input[0], -1) + else: + op_mem_block = self.get_op_mem_block( + op.type, + op.output_shape[i].dims) + mem_id = -1 + if len(self.idle_mem) > 0: + best_mem_add_size = sys.maxint + best_mem_waste_size = sys.maxint + for mid in self.idle_mem: + old_mem_block = self.mem_block[mid] + new_mem_block = self.resize_mem_block( + old_mem_block, op_mem_block) + add_mem_size = self.sub_mem_block(new_mem_block, + old_mem_block) + waste_mem_size = self.sub_mem_block(new_mem_block, + op_mem_block) + + # minimize add_mem_size; if best_mem_add_size is 0, + # then minimize waste_mem_size + if (best_mem_add_size > 0 and + add_mem_size < best_mem_add_size) \ + or (best_mem_add_size == 0 and + waste_mem_size < best_mem_waste_size): + best_mem_id = mid + best_mem_add_size = add_mem_size + best_mem_waste_size = waste_mem_size + best_mem_block = new_mem_block + + # if add mem size < op mem size, then reuse it + if best_mem_add_size <= self.mem_size(op_mem_block): + self.mem_block[best_mem_id] = best_mem_block + mem_id = best_mem_id + self.idle_mem.remove(mem_id) + + if mem_id == -1: + mem_id = self.mem_id_base() + self.total_mem_count + self.total_mem_count += 1 + self.mem_block[mem_id] = op_mem_block + + if mem_id != -1: + op.mem_id.extend([mem_id]) + self.op_mem[op.output[i]] = mem_id # de-ref input tensor mem - for ipt in op.input: + for idx in xrange(len(op.input)): + ipt = op.input[idx] if ipt in self.ref_counter: self.ref_counter[ipt] -= 1 - if self.ref_counter[ipt] == 0: + if self.ref_counter[ipt] == 0 and \ + (idx > 0 or not self.is_memory_reuse_op(op)): self.idle_mem.add(self.op_mem[ipt]) elif self.ref_counter[ipt] < 0: raise Exception('ref count is less than 0') @@ -170,8 +183,10 @@ class GPUMemoryOptimizer(MemoryOptimizer): mem_block[0] = output_shape[2] mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4) else: - mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4) - mem_block[1] = output_shape[0] * output_shape[1] + padded_output_shape = ([1, 1, 1, 1] + list(output_shape))[-4:] + mem_block[0] = padded_output_shape[2] * int( + (padded_output_shape[3] + 3) / 4) + mem_block[1] = padded_output_shape[0] * padded_output_shape[1] return mem_block def mem_size(self, memory_block): diff --git a/tools/converter.py b/tools/converter.py index 0d2b19c3509eedff23308d15f39c27cdcd1eb313..736fcec034c974271986e9c31e19c49cf4868650 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -482,6 +482,7 @@ def print_configuration(flags, configs): def download_model_files(model_file_path, model_output_dir, weight_file_path=""): + MaceLogger.info("Downloading model, please wait ...") if model_file_path.startswith("http://") or \ model_file_path.startswith("https://"): model_file = model_output_dir + "/model.pb" @@ -491,6 +492,7 @@ def download_model_files(model_file_path, weight_file_path.startswith("https://"): weight_file = model_output_dir + "/model.caffemodel" urllib.urlretrieve(weight_file_path, weight_file) + MaceLogger.info("Model downloaded successfully.") def get_model_files_path(model_file_path,