diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 28b3bc4636135aedd8e009396f699847a3f24a9c..11ab742e7534c101545fbd8bfd29c70dfb5d3d4f 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -289,7 +289,7 @@ void OpRegistryBase::GetInOutMemoryTypes(
     const std::string &op_type,
     OpConditionContext *context) const {
   MACE_CHECK(registry_.count(op_type) != 0,
-             op_type, " operation is not registered.");
+             op_type, " operation is not registered. op_type=", op_type);
   return registry_.at(op_type)->memory_type_setter(context);
 }
 
diff --git a/mace/ops/registry/ops_registry.cc b/mace/ops/registry/ops_registry.cc
index 1af424f1b3eaf742c29788746c20bc0eb2d5de4d..eafa78ceb876549fff28cd2eb48df719ff3a17e9 100644
--- a/mace/ops/registry/ops_registry.cc
+++ b/mace/ops/registry/ops_registry.cc
@@ -64,6 +64,7 @@ extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
 extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry);
 extern void RegisterReverse(OpRegistryBase *op_registry);
 extern void RegisterScalarMath(OpRegistryBase *op_registry);
+extern void RegisterSelect(OpRegistryBase *op_registry);
 extern void RegisterShape(OpRegistryBase *op_registry);
 extern void RegisterSlice(OpRegistryBase *op_registry);
 extern void RegisterSoftmax(OpRegistryBase *op_registry);
@@ -143,6 +144,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
   ops::RegisterResizeNearestNeighbor(this);
   ops::RegisterReverse(this);
   ops::RegisterScalarMath(this);
+  ops::RegisterSelect(this);
   ops::RegisterShape(this);
   ops::RegisterSlice(this);
   ops::RegisterSoftmax(this);
diff --git a/mace/ops/select.cc b/mace/ops/select.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4d094e651eea8e0113786ee078d4a3c04c8660e0
--- /dev/null
+++ b/mace/ops/select.cc
@@ -0,0 +1,184 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstring>
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace ops {
+
+template<DeviceType D, class T>
+class SelectOp;
+
+// CPU/float implementation of Select.
+//  * One input (a boolean condition): emits the row-major coordinates of
+//    every true element as a [true_count, rank] tensor.
+//  * Three inputs (condition, X, Y): emits X where the condition is true
+//    and Y elsewhere; the output has the shape of X.
+template<>
+class SelectOp<DeviceType::CPU, float> : public Operation {
+ public:
+  explicit SelectOp(OpConstructContext *context)
+      : Operation(context) {}
+
+  // Dispatches on the number of inputs (see the class comment).
+  MaceStatus Run(OpContext *context) override {
+    const auto input_count = this->InputSize();
+    if (input_count == 1) {
+      return RunWithNoData(context);
+    }
+    // RunWithData reads both X and Y, so exactly three inputs are needed.
+    MACE_CHECK(input_count == 3,
+               "Select expects 1 or 3 inputs, but got ", input_count);
+    return RunWithData(context);
+  }
+
+  // Coordinate mode: lists the coordinates of every true element of the
+  // condition, one row per element, in row-major order.
+  MaceStatus RunWithNoData(OpContext *context) {
+    static_cast<void>(context);  // no device resources are needed here
+    const Tensor *condition = this->Input(CONDITION);
+    Tensor *output = this->Output(OUTPUT);
+    const index_t condition_rank = condition->dim_size();
+    const index_t condition_size = condition->size();
+    // A scalar condition has no coordinates to report; reject it instead
+    // of producing a zero-width output.
+    MACE_CHECK(condition_rank > 0,
+               "Select without X/Y needs a condition of rank >= 1");
+
+    // Reserve room for the worst case (all true); shrunk once the number
+    // of true elements is known.
+    MACE_RETURN_IF_ERROR(output->Resize({condition_size, condition_rank}));
+    float *output_data = output->mutable_data<float>();
+    const bool *condition_data = condition->data<bool>();
+
+    // Track the coordinate of the current element while scanning the flat
+    // buffer, so one loop serves every rank without per-element division.
+    std::vector<index_t> coordinate(condition_rank, 0);
+    index_t true_count = 0;
+    index_t written = 0;
+    for (index_t flat = 0; flat < condition_size; ++flat) {
+      if (condition_data[flat]) {
+        for (index_t d = 0; d < condition_rank; ++d) {
+          output_data[written++] = static_cast<float>(coordinate[d]);
+        }
+        ++true_count;
+      }
+      // Advance like an odometer: bump the innermost dimension and carry
+      // into the outer ones whenever a dimension wraps around.
+      for (index_t d = condition_rank - 1; d >= 0; --d) {
+        if (++coordinate[d] < condition->dim(d)) {
+          break;
+        }
+        coordinate[d] = 0;
+      }
+    }
+
+    MACE_RETURN_IF_ERROR(output->Resize({true_count, condition_rank}));
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+  // Checks that X and Y have identical shapes and that the condition's
+  // shape matches the leading dimensions of X. Every violation is
+  // reported through MACE_CHECK; returns true when all checks pass.
+  bool CheckDataValid(const Tensor *condition,
+                      const Tensor *x, const Tensor *y) {
+    const index_t x_rank = x->dim_size();
+    const index_t y_rank = y->dim_size();
+    const index_t condition_rank = condition->dim_size();
+    MACE_CHECK(condition_rank <= x_rank && x_rank == y_rank,
+               "ranks are not compatible: condition ", condition_rank,
+               ", x ", x_rank, ", y ", y_rank);
+
+    for (index_t i = 0; i < condition_rank; ++i) {
+      MACE_CHECK(condition->dim(i) == x->dim(i),
+                 "dimensions are not equal: ",
+                 MakeString(condition->shape()),
+                 " vs. ",
+                 MakeString(x->shape()));
+    }
+
+    for (index_t i = 0; i < x_rank; ++i) {
+      MACE_CHECK(y->dim(i) == x->dim(i), "dimensions are not equal: ",
+                 MakeString(y->shape()), " vs. ", MakeString(x->shape()));
+    }
+
+    return true;
+  }
+
+  // Selection mode: writes X where the condition is true and Y elsewhere.
+  // When the condition has fewer elements than X, each condition element
+  // selects one contiguous block of x_size / condition_size values.
+  MaceStatus RunWithData(OpContext *context) {
+    const Tensor *condition = this->Input(CONDITION);
+    const Tensor *x = this->Input(X);
+    const Tensor *y = this->Input(Y);
+    // Validate with MACE_CHECK rather than MACE_ASSERT so the shape checks
+    // are not tied to assertion-enabled builds: the loops below index all
+    // three buffers using sizes derived from X alone.
+    MACE_CHECK(CheckDataValid(condition, x, y), "invalid Select inputs");
+
+    Tensor *output = this->Output(OUTPUT);
+    MACE_RETURN_IF_ERROR(output->Resize(x->shape()));
+    float *output_data = output->mutable_data<float>();
+    const bool *condition_data = condition->data<bool>();
+    const float *x_data = x->data<float>();
+    const float *y_data = y->data<float>();
+
+    const index_t condition_size = condition->size();
+    const index_t x_size = x->size();
+    utils::ThreadPool
+        &thread_pool = context->device()->cpu_runtime()->thread_pool();
+    if (condition_size == x_size) {
+      thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
+        for (index_t k = start; k < end; k += step) {
+          output_data[k] = condition_data[k] ? x_data[k] : y_data[k];
+        }
+      }, 0, x_size, 1);
+    } else if (x_size > condition_size) {  // broadcast
+      // Checked before dividing so a zero-sized condition cannot trigger a
+      // division by zero.
+      MACE_CHECK(condition_size > 0 && x_size % condition_size == 0,
+                 "x_size should be a multiple of condition_size");
+      const index_t block_size = x_size / condition_size;
+      const size_t raw_block_size = block_size * sizeof(float);
+      thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
+        for (index_t k = start; k < end; k += step) {
+          const index_t offset = block_size * k;
+          const float *source = condition_data[k] ? x_data : y_data;
+          memcpy(output_data + offset, source + offset, raw_block_size);
+        }
+      }, 0, condition_size, 1);
+    } else {
+      MACE_CHECK(false, "x_size should be bigger than condition_size");
+    }
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  MACE_OP_INPUT_TAGS(CONDITION, X, Y);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+void RegisterSelect(OpRegistryBase *op_registry) {
+  MACE_REGISTER_OP(op_registry, "Select", SelectOp,
+                   DeviceType::CPU, float);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/test/ccunit/mace/ops/select_test.cc b/test/ccunit/mace/ops/select_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f7f20be42fa2969a6873b39defc455f6bfda2f8b
--- /dev/null
+++ b/test/ccunit/mace/ops/select_test.cc
@@ -0,0 +1,151 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class SelectOpTest : public OpsTestBase {};
+
+namespace {
+template  // NOTE(review): "<...>" lists look stripped in this copy; confirm
+void TestSelect(const std::vector &input_shape,  // shape of the condition
+                const std::vector &input,  // condition values
+                const std::vector &x_shape,
+                const std::vector &x,  // empty => op gets the condition only
+                const std::vector &y_shape,
+                const std::vector &y,
+                const std::vector &output_shape,  // expected output shape
+                const std::vector &output) {  // expected output values
+  // Construct graph: X and Y are only wired in when x is non-empty
+  OpsTestNet net;
+  OpDefBuilder builder("Select", "SelectTest");
+  builder.Input("Input");
+  if (x.size() > 0) {
+    builder.Input("X").Input("Y");
+  }
+  builder.Output("Output").Finalize(net.NewOperatorDef());
+
+  net.AddInputFromArray(MakeString("Input"), input_shape, input);
+  if (x.size() > 0) {
+    net.AddInputFromArray(MakeString("X"), x_shape, x);
+    net.AddInputFromArray(MakeString("Y"), y_shape, y);
+  }
+
+  // Run the op, then compare "Output" against the expected tensor
+  net.RunOp();
+
+  net.AddInputFromArray("ExpectedOutput", output_shape, output);
+  ExpectTensorNear(*net.GetOutput("ExpectedOutput"),
+                   *net.GetOutput("Output"));
+}
+}  // namespace
+
+
+TEST_F(SelectOpTest, SimpleTestWithData) {  // condition has X's shape
+  TestSelect(
+      {2, 3},
+      {true, false, false, false, true, true},
+      {2, 3},
+      {3.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+      {2, 3},
+      {3.0, -1.0, -2.0, -3.0, 8.0, 9.0},
+      {2, 3},
+      {3.0, -1.0, -2.0, -3.0, 5.0, 6.0});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithDataBroadcast) {  // one flag per row
+  TestSelect(
+      {2},
+      {true, false},
+      {2, 3},
+      {3.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+      {2, 3},
+      {3.0, -1.0, -2.0, -3.0, 8.0, 9.0},
+      {2, 3},
+      {3, 2, 3, -3, 8, 9});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast1) {  // rank-1, no X/Y
+  TestSelect(
+      {2},
+      {true, false},
+      {}, {}, {}, {},
+      {1, 1},
+      {0});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast2) {  // rank-2, no X/Y
+  TestSelect(
+      {2, 3},
+      {true, false, false, false, true, true},
+      {}, {}, {}, {},
+      {3, 2},
+      {0, 0, 1, 1, 1, 2});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast3) {  // rank-3, no X/Y
+  TestSelect(
+      {2, 2, 3},
+      {true, false, false, false, true, true,
+       true, false, false, false, true, true},
+      {}, {}, {}, {},
+      {6, 3},
+      {0, 0, 0, 0, 1, 1, 0, 1, 2,
+       1, 0, 0, 1, 1, 1, 1, 1, 2});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast4) {  // rank-4, no X/Y
+  TestSelect(
+      {2, 2, 2, 3},
+      {true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true},
+      {}, {}, {}, {},
+      {12, 4},
+      {0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 2,
+       0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 2,
+       1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 2,
+       1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2});
+}
+
+TEST_F(SelectOpTest, SimpleTestWithNoDataBroadcast5) {  // rank-5, no X/Y
+  TestSelect(
+      {2, 2, 2, 2, 3},
+      {true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true,
+       true, false, false, false, true, true},
+      {}, {}, {}, {},
+      {24, 5},
+      {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 2,
+       0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 2,
+       0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 2,
+       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 2,
+       1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 2,
+       1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 2,
+       1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 2,
+       1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py
index e5b7bb20e2b879592c6d73d30e01e083ff355df7..6db141cedc91e4141984bf2cf2fddbac6524af14 100644
--- a/tools/python/transform/base_converter.py
+++ b/tools/python/transform/base_converter.py
@@ -136,6 +136,7 @@ MaceSupportedOps = [
     'ResizeNearestNeighbor',
     'Reverse',
     'ScalarMath',
+    'Select',
     'Slice',
     'Splice',
     'Split',
@@ -280,7 +281,7 @@ class MaceKeyword(object):
 
 
 class TransformerRule(Enum):
-    REMOVE_IDENTITY_OP = 1
+    REMOVE_USELESS_OP = 1
TRANSFORM_GLOBAL_POOLING = 2 FOLD_RESHAPE = 3 TRANSFORM_MATMUL_TO_FC = 4 @@ -526,9 +527,9 @@ class ConverterOption(object): else: self._transformer_option = [ # Model structure related transformation - TransformerRule.REMOVE_IDENTITY_OP, + TransformerRule.REMOVE_USELESS_OP, TransformerRule.TRANSFORM_FAKE_QUANTIZE, - TransformerRule.REMOVE_IDENTITY_OP, + TransformerRule.REMOVE_USELESS_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, TransformerRule.TRANSFORM_BASIC_LSTMCELL, diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index e91cab8ae8919115ab825d8081b58637bba58ac5..73a62dd73cce88da5cb77b8b6d824c232af99da2 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -57,6 +57,7 @@ TFSupportedOps = [ 'ArgMax', 'AvgPool', 'BatchMatMul', + 'BatchMatMulV2', 'BatchToSpaceND', 'BiasAdd', 'Cast', @@ -105,6 +106,7 @@ TFSupportedOps = [ 'ResizeNearestNeighbor', 'ReverseV2', 'Rsqrt', + 'Select', 'Shape', 'Sigmoid', 'Sign', @@ -134,7 +136,7 @@ TFSupportedOps = [six.b(op) for op in TFSupportedOps] TFTransformGraphOptions = [ 'strip_unused_nodes', - 'remove_nodes(op=Identity, op=CheckNumerics)', + 'remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)', 'fold_constants(ignore_errors=true)', 'fold_batch_norms', 'fold_old_batch_norms', @@ -211,6 +213,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.ArgMax.name: self.convert_argmax, TFOpType.AvgPool.name: self.convert_pooling, TFOpType.BatchMatMul.name: self.convert_matmul, + TFOpType.BatchMatMulV2.name: self.convert_matmul, TFOpType.BatchToSpaceND.name: self.convert_space_batch, TFOpType.BiasAdd.name: self.convert_biasadd, TFOpType.Cast.name: self.convert_cast, @@ -263,6 +266,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.ResizeBilinear.name: self.convert_resize_bilinear, TFOpType.ResizeNearestNeighbor.name: 
self.convert_resize_nearest_neighbor, # noqa TFOpType.ReverseV2.name: self.convert_reverse, + TFOpType.Select.name: self.convert_select, TFOpType.Shape.name: self.convert_shape, TFOpType.Sigmoid.name: self.convert_activation, TFOpType.Sign.name: self.convert_elementwise, @@ -993,6 +997,10 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) op.type = MaceOp.Reverse.name + def convert_select(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Select.name + def convert_stack(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.Stack.name diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index bbf14b4f56617a475d2eebc5395f7ef04f0f9dee..1c69a07e9864d379309a6f65a740be2986fb03a3 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -48,7 +48,7 @@ class Transformer(base_converter.ConverterInterface): self._registered_transformers = { TransformerRule.TRANSFORM_FAKE_QUANTIZE: self.transform_fake_quantize, - TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, + TransformerRule.REMOVE_USELESS_OP: self.remove_useless_op, TransformerRule.TRANSFORM_GLOBAL_POOLING: self.transform_global_pooling, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE: @@ -347,15 +347,15 @@ class Transformer(base_converter.ConverterInterface): return False - def remove_identity_op(self): + def remove_useless_op(self): net = self._model for op in net.op: if op.type == 'Identity': - print("Remove identity: %s(%s)" % (op.name, op.type)) + print("Remove useless op: %s(%s)" % (op.name, op.type)) self.safe_remove_node(op, self._producer.get(op.input[0], None)) return True - if op.type == 'Reshape' and \ + elif op.type == 'Reshape' and \ op.output_shape[0].dims == \ self.get_tensor_shape(op.input[0]): print("Remove useless reshape: %s(%s)" % (op.name, op.type))