From db4e94e379b408ed084cb635bdd8cbd6341b96d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Fri, 22 Jun 2018 14:28:07 +0800 Subject: [PATCH] Support nlp model and ops --- .../create_a_model_deployment.rst | 2 + mace/core/mace.cc | 3 - mace/core/net.cc | 4 +- mace/core/operator.cc | 2 + mace/kernels/bias_add.h | 67 ++++-- mace/kernels/eltwise.h | 209 ++++++++++++++++-- mace/kernels/opencl/bias_add.cc | 3 + mace/ops/bias_add.h | 6 +- mace/ops/bias_add_benchmark.cc | 1 + mace/ops/bias_add_test.cc | 3 + mace/ops/cast.cc | 34 +++ mace/ops/cast.h | 62 ++++++ mace/ops/cast_test.cc | 66 ++++++ mace/ops/concat.cc | 6 +- mace/ops/eltwise_test.cc | 59 +++++ mace/ops/reshape.cc | 6 + mace/ops/stack.cc | 5 + mace/ops/strided_slice.cc | 5 + mace/python/tools/converter.py | 10 +- .../tools/converter_tool/base_converter.py | 55 +++-- .../tools/converter_tool/caffe_converter.py | 4 + .../converter_tool/tensorflow_converter.py | 38 +++- .../tools/converter_tool/transformer.py | 84 +++---- mace/python/tools/memory_optimizer.py | 31 ++- mace/utils/string_util.h | 8 +- tools/converter.py | 14 +- tools/generate_data.py | 26 ++- 27 files changed, 685 insertions(+), 128 deletions(-) create mode 100644 mace/ops/cast.cc create mode 100644 mace/ops/cast.h create mode 100644 mace/ops/cast_test.cc diff --git a/docs/getting_started/create_a_model_deployment.rst b/docs/getting_started/create_a_model_deployment.rst index 4bcb3416..184d1101 100644 --- a/docs/getting_started/create_a_model_deployment.rst +++ b/docs/getting_started/create_a_model_deployment.rst @@ -65,6 +65,8 @@ Configurations - The shapes of the input tensors, in NHWC order. * - output_shapes - The shapes of the output tensors, in NHWC order. + * - input_ranges + - The numerical range of the input tensors, default [-1, 1]. It is only for test. * - validation_inputs_data - [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used. * - runtime diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 274f8245..000c85b7 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -217,9 +217,6 @@ MaceStatus MaceEngine::Impl::Run( << "' is not belong to model's inputs: " << MakeString(MapKeys(input_info_map_)); } - MACE_CHECK(input.second.shape().size() == 4, - "The Inputs' shape must be 4-dimension with NHWC format," - " please use 1 to fill missing dimensions"); Tensor *input_tensor = ws_->GetTensor(MakeString("mace_input_node_", input.first)); MACE_RETURN_IF_ERROR(input_tensor->Resize(input.second.shape())); diff --git a/mace/core/net.cc b/mace/core/net.cc index 346ca354..2f570319 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -63,7 +63,9 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) { for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) { auto &op = *iter; MACE_LATENCY_LOGGER(2, "Running operator ", op->debug_def().name(), "(", - op->debug_def().type(), ")"); + op->debug_def().type(), "), mem_id: ", + MakeListString(op->debug_def().mem_id().data(), + op->debug_def().mem_id().size())); bool future_wait = (device_type_ == DeviceType::GPU && (run_metadata != nullptr || std::distance(iter, operators_.end()) == 1)); diff --git a/mace/core/operator.cc b/mace/core/operator.cc index cde9baa9..3e3d2ba9 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -79,6 +79,7 @@ extern void Register_AddN(OperatorRegistry *op_registry); extern void Register_BatchNorm(OperatorRegistry *op_registry); extern void Register_BatchToSpaceND(OperatorRegistry *op_registry); extern void Register_BiasAdd(OperatorRegistry *op_registry); +extern void Register_Cast(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); @@ -127,6 +128,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_BatchNorm(this); ops::Register_BatchToSpaceND(this); ops::Register_BiasAdd(this); + ops::Register_Cast(this); ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); diff --git a/mace/kernels/bias_add.h b/mace/kernels/bias_add.h index cf09c8a5..1cd8421c 100644 --- a/mace/kernels/bias_add.h +++ b/mace/kernels/bias_add.h @@ -15,6 +15,7 @@ #ifndef MACE_KERNELS_BIAS_ADD_H_ #define MACE_KERNELS_BIAS_ADD_H_ +#include #include #include @@ -29,20 +30,27 @@ namespace mace { namespace kernels { -template +struct BiasAddFunctorBase { + explicit BiasAddFunctorBase(const DataFormat data_format) { + data_format_ = data_format; + } + + DataFormat data_format_; +}; + +template struct BiasAddFunctor; -template<> -struct BiasAddFunctor { +template <> +struct BiasAddFunctor : BiasAddFunctorBase { + explicit BiasAddFunctor(const DataFormat data_format) + : BiasAddFunctorBase(data_format) {} + MaceStatus operator()(const Tensor *input, - const Tensor *bias, - Tensor *output, - StatsFuture *future) { + const Tensor *bias, + Tensor *output, + StatsFuture *future) { MACE_UNUSED(future); - const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t height = input->dim(2); - const index_t width = input->dim(3); Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard bias_mapper(bias); @@ -52,12 +60,31 @@ struct BiasAddFunctor { const float *bias_ptr = bias->data(); float *output_ptr = output->mutable_data(); + if (input->dim_size() == 4 && data_format_ == NCHW) { + const index_t batch = input->dim(0); + const index_t channels = input->dim(1); + const index_t height_width = input->dim(2) * input->dim(3); + #pragma omp parallel for collapse(2) - for (index_t n = 0; n < batch; ++n) { - for (index_t c = 0; c < channels; ++c) { - for (index_t hw = 0; hw < height * width; ++hw) { - index_t pos = (n * channels + c) * height * width + hw; + for (index_t n = 0; n < batch; ++n) { + for (index_t c = 0; c < channels; ++c) { + for (index_t hw = 0; hw < height_width; ++hw) { + index_t pos = (n * channels + c) * height_width + hw; + output_ptr[pos] = input_ptr[pos] + bias_ptr[c]; + } + } + } + } else { + const std::vector &shape = input->shape(); + const index_t fused_batch = std::accumulate( + shape.begin(), shape.end() - 1, 1, std::multiplies()); + const index_t channels = *shape.rbegin(); +#pragma omp parallel for + for (index_t n = 0; n < fused_batch; ++n) { + index_t pos = n * channels; + for (index_t c = 0; c < channels; ++c) { output_ptr[pos] = input_ptr[pos] + bias_ptr[c]; + ++pos; } } } @@ -67,12 +94,14 @@ struct BiasAddFunctor { }; #ifdef MACE_ENABLE_OPENCL -template -struct BiasAddFunctor { +template +struct BiasAddFunctor : BiasAddFunctorBase { + explicit BiasAddFunctor(const DataFormat data_format) + : BiasAddFunctorBase(data_format) {} MaceStatus operator()(const Tensor *input, - const Tensor *bias, - Tensor *output, - StatsFuture *future); + const Tensor *bias, + Tensor *output, + StatsFuture *future); cl::Kernel kernel_; uint32_t kwg_size_; std::unique_ptr kernel_error_; diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index bfcfa9ce..05695fd8 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "mace/core/future.h" #include "mace/core/tensor.h" @@ -44,6 +45,157 @@ enum EltwiseType { NONE = 10, }; +inline index_t GetIndex(const std::vector &shape, + const std::vector &index) { + index_t idx = 0; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] > 1) { + idx = idx * shape[i] + index[i]; + } + } + return idx; +} + +inline void IncreaseIndex(const std::vector &shape, + std::vector *index) { + for (index_t i = static_cast(shape.size()) - 1; i >= 0; --i) { + ++(*index)[i]; + if ((*index)[i] >= shape[i]) { + (*index)[i] -= shape[i]; + } else { + break; + } + } +} + +inline void TensorGeneralBroadcastEltwise(const EltwiseType type, + const float *input0, + const float *input1, + const std::vector &coeff, + const bool swapped, + const std::vector + &input0_shape, + const std::vector + &input1_shape, + const std::vector + &output_shape, + float *output) { + const index_t output_size = std::accumulate(output_shape.begin(), + output_shape.end(), + 1, + std::multiplies()); + std::vector out_index(output_shape.size(), 0); + switch (type) { + case SUM: + if (coeff.empty()) { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input0[idx0] + input1[idx1]; + IncreaseIndex(output_shape, &out_index); + } + } else { + std::vector coeff_copy = coeff; + if (swapped) { + std::swap(coeff_copy[0], coeff_copy[1]); + } + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = + input0[idx0] * coeff_copy[0] + input1[idx1] * coeff_copy[1]; + IncreaseIndex(output_shape, &out_index); + } + } + break; + case SUB: + if (!swapped) { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input0[idx0] - input1[idx1]; + IncreaseIndex(output_shape, &out_index); + } + } else { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input1[idx1] - input0[idx0]; + IncreaseIndex(output_shape, &out_index); + } + } + break; + case PROD: + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input0[idx0] * input1[idx1]; + IncreaseIndex(output_shape, &out_index); + } + break; + case DIV: + if (!swapped) { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input0[idx0] / input1[idx1]; + IncreaseIndex(output_shape, &out_index); + } + } else { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input1[idx1] / input0[idx0]; + IncreaseIndex(output_shape, &out_index); + } + } + break; + case MIN: + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = std::min(input1[idx1], input0[idx0]); + IncreaseIndex(output_shape, &out_index); + } + break; + case MAX: + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = std::max(input1[idx1], input0[idx0]); + IncreaseIndex(output_shape, &out_index); + } + break; + case SQR_DIFF: + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = std::pow(input1[idx1] - input0[idx0], 2.f); + IncreaseIndex(output_shape, &out_index); + } + break; + case POW: + if (!swapped) { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = std::pow(input0[idx0], input1[idx1]); + IncreaseIndex(output_shape, &out_index); + } + } else { + for (index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = std::pow(input1[idx1], input0[idx0]); + IncreaseIndex(output_shape, &out_index); + } + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type; + } +} + inline void TensorBroadcastEltwise(const EltwiseType type, const float *input0, const float *input1, @@ -662,40 +814,71 @@ struct EltwiseFunctor : EltwiseFunctorBase { && input1->dim(0) == input0->dim(1))), "only support broadcast channel dimension"); } else { - if (rank_diff > 0 && rank_diff < input0->dim_size()) { - for (uint32_t i = 0; i < input1->dim_size(); ++i) { - MACE_CHECK(input0->dim(rank_diff + i) == input1->dim(i), - "Element-Wise op only support tail dimensions broadcast"); - } + for (uint32_t i = 0; i < input1->dim_size(); ++i) { + MACE_CHECK(input0->dim(rank_diff + i) == 1 + || input1->dim(i) == 1 + || input0->dim(rank_diff + i) == input1->dim(i), + "Element-Wise op only support tail dimensions broadcast"); } } - index_t common_size = input1->size(); - index_t diff_size = input0->size() / common_size; - - MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); - Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input1_guard(input1); - Tensor::MappingGuard output_guard(output); const float *input0_ptr = input0->data(); const float *input1_ptr = input1->data(); - float *output_ptr = output->mutable_data(); if (data_format_ == NCHW && input1->dim_size() > 0 && input1->size() < input0->size()) { + MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); + Tensor::MappingGuard output_guard(output); + float *output_ptr = output->mutable_data(); TensorEltwisePerChannel( type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), input0->dim(2) * input0->dim(3), swapped, output_ptr); } else { - if (input1->size() == input0->size()) { + const std::vector &input0_shape = input0->shape(); + std::vector input1_shape(rank_diff, 1); + input1_shape.insert(input1_shape.end(), + input1->shape().begin(), + input1->shape().end()); + + std::vector output_shape(input0->dim_size(), 0); + for (unsigned int i = 0; i < input0_shape.size(); ++i) { + output_shape[i] = std::max(input0_shape[i], input1_shape[i]); + } + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + Tensor::MappingGuard output_guard(output); + float *output_ptr = output->mutable_data(); + + bool need_general_broadcast = false; + for (uint32_t i = 0; i < input1->dim_size(); ++i) { + if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1) + || (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) { + need_general_broadcast = true; + break; + } + } + + if (need_general_broadcast) { + TensorGeneralBroadcastEltwise(type_, + input0_ptr, + input1_ptr, + coeff_, + swapped, + input0_shape, + input1_shape, + output_shape, + output_ptr); + } else if (input1->size() == input0->size()) { TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), swapped, output_ptr); } else if (input1->size() < input0->size()) { if (input1->size() > 1) { + index_t common_size = input1->size(); + index_t diff_size = input0->size() / common_size; TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, diff_size, common_size, swapped, output_ptr); } else { diff --git a/mace/kernels/opencl/bias_add.cc b/mace/kernels/opencl/bias_add.cc index 63fd1033..d1e58bd3 100644 --- a/mace/kernels/opencl/bias_add.cc +++ b/mace/kernels/opencl/bias_add.cc @@ -26,6 +26,9 @@ MaceStatus BiasAddFunctor::operator()(const Tensor *input, const Tensor *bias, Tensor *output, StatsFuture *future) { + MACE_CHECK(input->dim_size() == 4 && data_format_ == NHWC, + "gpu only support biasadd for 4-dimensional NHWC format tensor"); + const index_t batch = input->dim(0); const index_t height = input->dim(1); const index_t width = input->dim(2); diff --git a/mace/ops/bias_add.h b/mace/ops/bias_add.h index cc9c4bd9..901c1e74 100644 --- a/mace/ops/bias_add.h +++ b/mace/ops/bias_add.h @@ -25,14 +25,14 @@ template class BiasAddOp : public Operator { public: BiasAddOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), functor_() {} + : Operator(operator_def, ws), + functor_(static_cast(OperatorBase::GetOptionalArg( + "data_format", NHWC))) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); const Tensor *bias = this->Input(BIAS); - MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", - input->dim_size()); MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ", bias->dim_size()); diff --git a/mace/ops/bias_add_benchmark.cc b/mace/ops/bias_add_benchmark.cc index c0f6ad72..ca8500ed 100644 --- a/mace/ops/bias_add_benchmark.cc +++ b/mace/ops/bias_add_benchmark.cc @@ -42,6 +42,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { OpDefBuilder("BiasAdd", "BiasAddBM") .Input("Input") .Input("Bias") + .AddIntArg("data_format", NCHW) .Output("Output") .Finalize(net.NewOperatorDef()); } else if (D == DeviceType::GPU) { diff --git a/mace/ops/bias_add_test.cc b/mace/ops/bias_add_test.cc index c030b8f0..c4158454 100644 --- a/mace/ops/bias_add_test.cc +++ b/mace/ops/bias_add_test.cc @@ -37,6 +37,7 @@ void BiasAddSimple() { OpDefBuilder("BiasAdd", "BiasAddTest") .Input("InputNCHW") .Input("Bias") + .AddIntArg("data_format", NCHW) .Output("OutputNCHW") .Finalize(net.NewOperatorDef()); // Run @@ -99,6 +100,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { OpDefBuilder("BiasAdd", "BiasAddTest") .Input("InputNCHW") .Input("Bias") + .AddIntArg("data_format", NCHW) .Output("OutputNCHW") .Finalize(net.NewOperatorDef()); @@ -155,6 +157,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { OpDefBuilder("BiasAdd", "BiasAddTest") .Input("InputNCHW") .Input("Bias") + .AddIntArg("data_format", NCHW) .Output("OutputNCHW") .Finalize(net.NewOperatorDef()); diff --git a/mace/ops/cast.cc b/mace/ops/cast.cc new file mode 100644 index 00000000..556a79f8 --- /dev/null +++ b/mace/ops/cast.cc @@ -0,0 +1,34 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/cast.h" + +namespace mace { +namespace ops { + +void Register_Cast(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + CastOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + CastOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/cast.h b/mace/ops/cast.h new file mode 100644 index 00000000..a8b283d0 --- /dev/null +++ b/mace/ops/cast.h @@ -0,0 +1,62 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_CAST_H_ +#define MACE_OPS_CAST_H_ + +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class CastOp : public Operator { + public: + CastOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws) {} + + MaceStatus Run(StatsFuture *future) override { + MACE_UNUSED(future); + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_RETURN_IF_ERROR(output->ResizeLike(input)) + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + auto src_dtype = input->dtype(); + + auto output_data = output->mutable_data
(); + +#define MACE_CAST_COPY \ + auto input_data = input->data(); \ + for (index_t i = 0; i < output->size(); ++i) { \ + output_data[i] = static_cast
(input_data[i]); \ + } + + MACE_RUN_WITH_TYPE_ENUM(src_dtype, MACE_CAST_COPY); + + return MACE_SUCCESS; + } + + private: + MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_CAST_H_ diff --git a/mace/ops/cast_test.cc b/mace/ops/cast_test.cc new file mode 100644 index 00000000..e483f429 --- /dev/null +++ b/mace/ops/cast_test.cc @@ -0,0 +1,66 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class CastOpTest : public OpsTestBase {}; + +namespace { +template +void TestCast(const std::vector &input_shape, + const std::vector &input) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Cast", "CastTest") + .Input("Input") + .Output("Output") + .AddIntArg("T", DataTypeToEnum::v()) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray("Input", input_shape, input); + + // Run + net.RunOp(); + + auto input_tensor = net.GetTensor("Input"); + auto output_tensor = net.GetTensor("Output"); + + EXPECT_THAT(output_tensor->shape(), ::testing::ContainerEq(input_shape)); + + const int size = output_tensor->size(); + for (int i = 0; i < size; ++i) { + Expector::Near( + *input_tensor, *output_tensor, 1e-5, 1.f); + } +} +} // namespace + +TEST_F(CastOpTest, TestCastFromFloatToInt32) { + TestCast({1, 2, 3}, {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}); +} + +TEST_F(CastOpTest, TestCastFromInt32ToFloat) { + TestCast({1, 2, 3}, {1, 2, 3, 4, 5, 6}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc index 0275d497..bf82f796 100644 --- a/mace/ops/concat.cc +++ b/mace/ops/concat.cc @@ -23,7 +23,11 @@ void Register_Concat(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), ConcatOp); - + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ConcatOp); #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat") .Device(DeviceType::GPU) diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc index 37666b33..84ec9d06 100644 --- a/mace/ops/eltwise_test.cc +++ b/mace/ops/eltwise_test.cc @@ -135,6 +135,41 @@ void SimpleTensorEltwise(const kernels::EltwiseType type, ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } + +template +void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type, + const std::vector &shape0, + const std::vector &input0, + const std::vector &shape1, + const std::vector &input1, + const std::vector &output_shape, + const std::vector &output, + const std::vector &coeff = {}) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input0", shape0, input0); + net.AddInputFromArray("Input1", shape1, input1); + + if (D == DeviceType::CPU) { + auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest") + .Input("Input0") + .Input("Input1") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", coeff) + .Output("Output"); + op_builder.Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + } else { + MACE_NOT_IMPLEMENTED; + } + + auto expected = CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} } // namespace TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) { @@ -559,6 +594,30 @@ TEST_F(EltwiseOpTest, RandomTensorTensorHalf) { {3, 31, 37, 17}); } +TEST_F(EltwiseOpTest, TensorGeneralBroadcast) { + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, + {1, 2}, {1, 1, 2, 3}, {2, 3, 4, 6, 7, 8}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, + {1, 2}, {1, 1, 2, 3}, {0, 1, 2, 2, 3, 4}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, + {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 8, 10, 12}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, + {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 2, 2.5, 3}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, + {1, 2}, {1, 1, 2, 3}, {1, 1, 1, 2, 2, 2}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, + {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); + TensorGeneralBroadcastEltwise( + kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, + {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 4, 4, 9, 16}); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index ff0befc2..aefc6337 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -23,6 +23,12 @@ void Register_Reshape(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), ReshapeOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ReshapeOp); + #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") diff --git a/mace/ops/stack.cc b/mace/ops/stack.cc index f951460a..992ee408 100644 --- a/mace/ops/stack.cc +++ b/mace/ops/stack.cc @@ -23,6 +23,11 @@ void Register_Stack(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), StackOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + StackOp); } } // namespace ops diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc index 674c766f..84cf7883 100644 --- a/mace/ops/strided_slice.cc +++ b/mace/ops/strided_slice.cc @@ -23,6 +23,11 @@ void Register_StridedSlice(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), StridedSliceOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + StridedSliceOp); } } // namespace ops diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index 86392e5d..55ca3c6f 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -108,7 +108,10 @@ def main(unused_args): print("%s does not support dsp runtime yet." % FLAGS.platform) sys.exit(-1) else: - option = cvt.ConverterOption() + if FLAGS.transformers: + option = cvt.ConverterOption(FLAGS.transformers.split(',')) + else: + option = cvt.ConverterOption() option.winograd_enabled = bool(FLAGS.winograd) input_node_names = FLAGS.input_node.split(',') @@ -285,6 +288,11 @@ def parse_args(): type=str, default="fp16_fp32", help="fp16_fp32/fp32_fp32") + parser.add_argument( + "--transformers", + type=str, + default="", + help="model transformers") return parser.parse_known_args() diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index b5ef56b9..74e6a080 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -74,6 +74,7 @@ MaceSupportedOps = [ 'BatchNorm', 'BatchToSpaceND', 'BiasAdd', + 'Cast', 'ChannelShuffle', 'Concat', 'Conv2D', @@ -177,9 +178,10 @@ class TransformerRule(Enum): TRANSPOSE_DATA_FORMAT = 15 TRANSFORM_GLOBAL_CONV_TO_FC = 16 TRANSFORM_BUFFER_IMAGE = 17 - ADD_DEVICE_AND_DATA_TYPE = 18 + ADD_DEVICE = 18 SORT_BY_EXECUTION = 19 ADD_IN_OUT_TENSOR_INFO = 20 + ADD_MACE_INPUT_AND_OUTPUT_NODES = 21 class ConverterInterface(object): @@ -219,34 +221,39 @@ class NodeInfo(object): class ConverterOption(object): """A class for specifying options passed to converter tool""" - def __init__(self): + def __init__(self, transformers=None): self._input_nodes = {} self._output_nodes = {} self._data_type = mace_pb2.DT_FLOAT self._device = DeviceType.CPU.value self._winograd_enabled = False - self._transformer_option = [ - TransformerRule.REMOVE_IDENTITY_OP, - TransformerRule.TRANSFORM_GLOBAL_POOLING, - TransformerRule.FOLD_RESHAPE, - TransformerRule.TRANSFORM_MATMUL_TO_FC, - TransformerRule.FOLD_BATCHNORM, - TransformerRule.FOLD_CONV_AND_BN, - TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, - TransformerRule.TRANSFORM_GPU_WINOGRAD, - TransformerRule.TRANSFORM_ADD_TO_BIASADD, - TransformerRule.FOLD_BIASADD, - TransformerRule.FLATTEN_ATROUS_CONV, - TransformerRule.FOLD_ACTIVATION, - TransformerRule.TRANSPOSE_FILTERS, - TransformerRule.TRANSPOSE_DATA_FORMAT, - TransformerRule.ADD_IN_OUT_TENSOR_INFO, - TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, - TransformerRule.RESHAPE_FC_WEIGHT, - TransformerRule.TRANSFORM_BUFFER_IMAGE, - TransformerRule.ADD_DEVICE_AND_DATA_TYPE, - TransformerRule.SORT_BY_EXECUTION, - ] + if transformers: + self._transformer_option = [TransformerRule[transformer] + for transformer in transformers] + else: + self._transformer_option = [ + TransformerRule.REMOVE_IDENTITY_OP, + TransformerRule.TRANSFORM_GLOBAL_POOLING, + TransformerRule.FOLD_RESHAPE, + TransformerRule.TRANSFORM_MATMUL_TO_FC, + TransformerRule.FOLD_BATCHNORM, + TransformerRule.FOLD_CONV_AND_BN, + TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, + TransformerRule.TRANSFORM_GPU_WINOGRAD, + TransformerRule.TRANSFORM_ADD_TO_BIASADD, + TransformerRule.FOLD_BIASADD, + TransformerRule.FLATTEN_ATROUS_CONV, + TransformerRule.FOLD_ACTIVATION, + TransformerRule.TRANSPOSE_FILTERS, + TransformerRule.TRANSPOSE_DATA_FORMAT, + TransformerRule.ADD_IN_OUT_TENSOR_INFO, + TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, + TransformerRule.RESHAPE_FC_WEIGHT, + TransformerRule.TRANSFORM_BUFFER_IMAGE, + TransformerRule.ADD_DEVICE, + TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES, + TransformerRule.SORT_BY_EXECUTION, + ] @property def input_nodes(self): diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index 9e7c4d31..b39b18b2 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -341,6 +341,10 @@ class CaffeConverter(base_converter.ConverterInterface): op.input.extend(caffe_op.layer.bottom) op.output.extend(caffe_op.layer.top) + data_type_arg = op.arg.add() + data_type_arg.name = 'T' + data_type_arg.i = self._option.data_type + ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) return op diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 63d046bd..99c077c9 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -92,6 +92,7 @@ TFSupportedOps = [ 'Slice', 'Stack', 'Pack', + 'Cast', ] TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) @@ -181,7 +182,8 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.StridedSlice.name: self.convert_stridedslice, TFOpType.Slice.name: self.convert_slice, TFOpType.Pack.name: self.convert_stack, - TFOpType.Stack.name: self.convert_stack + TFOpType.Stack.name: self.convert_stack, + TFOpType.Cast.name: self.convert_cast } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -300,6 +302,19 @@ class TensorflowConverter(base_converter.ConverterInterface): output_shape = op.output_shape.add() output_shape.dims.extend(self.infer_tensor_shape(tf_output)) + data_type_arg = op.arg.add() + data_type_arg.name = 'T' + try: + dtype = tf_op.get_attr('T') + if dtype == tf.int32: + data_type_arg.i = mace_pb2.DT_INT32 + elif dtype == tf.float32: + data_type_arg.i = self._option.data_type + else: + mace_check(False, "data type %s not supported" % dtype) + except ValueError: + data_type_arg.i = self._option.data_type + ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) return op @@ -367,7 +382,7 @@ class TensorflowConverter(base_converter.ConverterInterface): value_arg.f = tf_op.inputs[0].eval().astype(np.float32) self._skip_tensor.add(tf_op.inputs[0].name) del op.input[0] - elif len(tf_op.inputs[1].shape) == 0: + elif len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0: value_arg = op.arg.add() value_arg.name = MaceKeyword.mace_value_str value_arg.f = tf_op.inputs[1].eval().astype(np.float32) @@ -655,6 +670,9 @@ class TensorflowConverter(base_converter.ConverterInterface): def convert_slice(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.StridedSlice.name + arg = op.arg.add() + arg.name = 'slice' + arg.i = 1 def convert_stack(self, tf_op): op = self.convert_general_op(tf_op) @@ -666,3 +684,19 @@ class TensorflowConverter(base_converter.ConverterInterface): axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str) except ValueError: axis_arg.i = 0 + + def convert_cast(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Cast.name + + data_type_arg = ConverterUtil.get_arg(op, 'T') + try: + dtype = tf_op.get_attr('DstT') + if dtype == tf.int32: + data_type_arg.i = mace_pb2.DT_INT32 + elif dtype == tf.float32: + data_type_arg.i = self._option.data_type + else: + mace_check(False, "data type %s not supported" % dtype) + except ValueError: + data_type_arg.i = self._option.data_type diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 4a9e3fbe..9fc8a346 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -53,30 +53,6 @@ class Transformer(base_converter.ConverterInterface): """ def __init__(self, option, model): - # DO NOT reorder the following transformers' order - self._registered_transformers_order = [ - TransformerRule.REMOVE_IDENTITY_OP, - TransformerRule.TRANSFORM_GLOBAL_POOLING, - TransformerRule.FOLD_RESHAPE, - TransformerRule.TRANSFORM_MATMUL_TO_FC, - TransformerRule.FOLD_BATCHNORM, - TransformerRule.FOLD_CONV_AND_BN, - TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, - TransformerRule.TRANSFORM_GPU_WINOGRAD, - TransformerRule.TRANSFORM_ADD_TO_BIASADD, - TransformerRule.FOLD_BIASADD, - TransformerRule.FOLD_ACTIVATION, - TransformerRule.FLATTEN_ATROUS_CONV, - TransformerRule.FOLD_ACTIVATION, - TransformerRule.TRANSPOSE_FILTERS, - TransformerRule.TRANSPOSE_DATA_FORMAT, - TransformerRule.ADD_IN_OUT_TENSOR_INFO, - TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, - TransformerRule.RESHAPE_FC_WEIGHT, - TransformerRule.TRANSFORM_BUFFER_IMAGE, - TransformerRule.ADD_DEVICE_AND_DATA_TYPE, - TransformerRule.SORT_BY_EXECUTION, - ] self._registered_transformers = { TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, TransformerRule.TRANSFORM_GLOBAL_POOLING: @@ -105,8 +81,10 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight, TransformerRule.TRANSFORM_BUFFER_IMAGE: self.transform_buffer_image, - TransformerRule.ADD_DEVICE_AND_DATA_TYPE: - self.add_device_and_data_type, + TransformerRule.ADD_DEVICE: + self.add_device, + TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES: + self.add_mace_input_and_output_nodes, TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, } @@ -119,18 +97,18 @@ class Transformer(base_converter.ConverterInterface): self._consumers = {} self._producer = {} self._target_data_format = DataFormat.NHWC + self._input_output_added = False if self._option.device == DeviceType.CPU.value: self._target_data_format = DataFormat.NCHW def run(self): - for key in self._registered_transformers_order: - if key in self._option.transformer_option: - transformer = self._registered_transformers[key] - while True: - self.construct_ops_and_consumers() - changed = transformer() - if not changed: + for key in self._option.transformer_option: + transformer = self._registered_transformers[key] + while True: + self.construct_ops_and_consumers() + changed = transformer() + if not changed: break return self._model @@ -900,6 +878,8 @@ class Transformer(base_converter.ConverterInterface): else: op.type = MaceOp.Identity.name + self._input_output_added = True + return False def transpose_filters(self): @@ -1060,6 +1040,8 @@ class Transformer(base_converter.ConverterInterface): ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) + self._input_output_added = True + return False def fold_reshape(self): @@ -1164,16 +1146,13 @@ class Transformer(base_converter.ConverterInterface): in_channels * filter_width * filter_height][:] - def add_device_and_data_type(self): + def add_device(self): # TODO(liuqi) add device definition in OperatorDef net = self._model for op in net.op: arg = op.arg.add() arg.name = MaceKeyword.mace_device arg.i = self._option.device - data_type_arg = op.arg.add() - data_type_arg.name = 'T' - data_type_arg.i = self._option.data_type return False @@ -1188,6 +1167,37 @@ class Transformer(base_converter.ConverterInterface): self.sort_dfs(producer_op, visited, sorted_nodes) sorted_nodes.append(op) + def add_mace_input_and_output_nodes(self): + if self._input_output_added: + return + + print("add mace input and output nodes") + + for input_node in self._option.input_nodes.values(): + new_input_name = MaceKeyword.mace_input_node_name \ + + '_' + input_node.name + op_def = self._model.op.add() + op_def.name = self.normalize_op_name(input_node.name) + op_def.type = MaceOp.Identity.name + op_def.input.extend([new_input_name]) + op_def.output.extend([input_node.name]) + output_shape = op_def.output_shape.add() + output_shape.dims.extend(input_node.shape) + + ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) + + for output_node in self._option.output_nodes.values(): + output_name = MaceKeyword.mace_output_node_name \ + + '_' + output_node.name + op_def = self._model.op.add() + op_def.name = self.normalize_op_name(output_name) + op_def.type = MaceOp.Identity.name + op_def.input.extend([output_node.name]) + op_def.output.extend([output_name]) + output_shape = op_def.output_shape.add() + output_shape.dims.extend( + self._producer[output_node.name].output_shape[0].dims) + def sort_by_execution(self): print("Sort by execution") net = self._model diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index d2325035..6a82f8ad 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -24,7 +24,8 @@ class MemoryOptimizer(object): self.op_mem = {} # op_name->mem_id self.mem_block = {} # mem_id->[size] or mem_id->[x, y] self.total_mem_count = 0 - self.ref_counter = {} + self.input_ref_counter = {} + self.mem_ref_counter = {} consumers = {} for op in net_def.op: @@ -41,9 +42,10 @@ class MemoryOptimizer(object): for output in op.output: tensor_name = output if tensor_name in consumers: - self.ref_counter[tensor_name] = len(consumers[tensor_name]) + self.input_ref_counter[tensor_name] = \ + len(consumers[tensor_name]) else: - self.ref_counter[tensor_name] = 0 + self.input_ref_counter[tensor_name] = 0 def op_need_optimize_memory(self, op): return True @@ -93,8 +95,8 @@ class MemoryOptimizer(object): if not self.op_need_optimize_memory(op): continue if not op.output_shape: - print('WARNING: There is no output shape information to ' - 'do memory optimization.') + print("WARNING: There is no output shape information to " + "do memory optimization. %s (%s)" % (op.name, op.type)) return if len(op.output_shape) != len(op.output): print('WARNING: the number of output shape is not equal to ' @@ -146,16 +148,23 @@ class MemoryOptimizer(object): if mem_id != -1: op.mem_id.extend([mem_id]) self.op_mem[op.output[i]] = mem_id + if mem_id not in self.mem_ref_counter: + self.mem_ref_counter[mem_id] = 1 + else: + self.mem_ref_counter[mem_id] += 1 # de-ref input tensor mem for idx in xrange(len(op.input)): ipt = op.input[idx] - if ipt in self.ref_counter: - self.ref_counter[ipt] -= 1 - if self.ref_counter[ipt] == 0 and \ - (idx > 0 or not self.is_memory_reuse_op(op)): - self.idle_mem.add(self.op_mem[ipt]) - elif self.ref_counter[ipt] < 0: + if ipt in self.input_ref_counter: + self.input_ref_counter[ipt] -= 1 + if self.input_ref_counter[ipt] == 0 \ + and ipt in self.op_mem: + mem_id = self.op_mem[ipt] + self.mem_ref_counter[mem_id] -= 1 + if self.mem_ref_counter[mem_id] == 0: + self.idle_mem.add(self.op_mem[ipt]) + elif self.input_ref_counter[ipt] < 0: raise Exception('ref count is less than 0') self.add_net_mem_blocks() diff --git a/mace/utils/string_util.h b/mace/utils/string_util.h index 576dd8c5..d1f8cb4c 100644 --- a/mace/utils/string_util.h +++ b/mace/utils/string_util.h @@ -54,10 +54,9 @@ std::string MakeString(const Args &... args) { } template -std::string MakeString(const std::vector &args) { +std::string MakeListString(const T *args, size_t size) { std::stringstream ss; ss << "["; - const size_t size = args.size(); for (size_t i = 0; i < size; ++i) { ss << args[i]; if (i < size - 1) { @@ -68,6 +67,11 @@ std::string MakeString(const std::vector &args) { return ss.str(); } +template +std::string MakeString(const std::vector &args) { + return MakeListString(args.data(), args.size()); +} + // Specializations for already-a-string types. template <> inline std::string MakeString(const std::string &str) { diff --git a/tools/converter.py b/tools/converter.py index db2e74ab..19837129 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -136,6 +136,7 @@ class YAMLKeyword(object): subgraphs = 'subgraphs' input_tensors = 'input_tensors' input_shapes = 'input_shapes' + input_ranges = 'input_ranges' output_tensors = 'output_tensors' output_shapes = 'output_shapes' runtime = 'runtime' @@ -145,6 +146,7 @@ class YAMLKeyword(object): obfuscate = 'obfuscate' winograd = 'winograd' validation_inputs_data = 'validation_inputs_data' + transformers = 'transformers' # keep it private for now class ModuleName(object): @@ -640,7 +642,8 @@ def convert_model(configs): model_config[YAMLKeyword.winograd], model_config[YAMLKeyword.obfuscate], configs[YAMLKeyword.build_type], - data_type) + data_type, + ",".join(model_config.get(YAMLKeyword.transformers, []))) if configs[YAMLKeyword.build_type] == BuildType.proto: sh.mv("-f", @@ -732,7 +735,8 @@ def build_specific_lib(target_abi, target_soc, serial_num, model_output_dir, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], - subgraphs[0][YAMLKeyword.validation_inputs_data]) + subgraphs[0][YAMLKeyword.validation_inputs_data], + input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None)) device_type = parse_device_type(RuntimeType.gpu) sh_commands.tuning_run( @@ -975,7 +979,8 @@ def run_specific_target(flags, configs, target_abi, model_output_dir, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], - subgraphs[0][YAMLKeyword.validation_inputs_data]) + subgraphs[0][YAMLKeyword.validation_inputs_data], + input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None)) runtime_list = [] if target_abi == ABIType.host: runtime_list.extend([RuntimeType.cpu]) @@ -1123,7 +1128,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num): model_output_dir, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], - subgraphs[0][YAMLKeyword.validation_inputs_data]) + subgraphs[0][YAMLKeyword.validation_inputs_data], + input_ranges=subgraphs[0].get(YAMLKeyword.input_ranges, None)) runtime_list = [] if target_abi == ABIType.host: runtime_list.extend([RuntimeType.cpu]) diff --git a/tools/generate_data.py b/tools/generate_data.py index 747f0c23..d62297cc 100644 --- a/tools/generate_data.py +++ b/tools/generate_data.py @@ -23,25 +23,34 @@ import common # python generate_data.py \ # --input_node input_node \ # --input_shape 1,64,64,3 \ -# --input_file input_file -# +# --input_file input_file \ +# --input_ranges -1,1 -def generate_data(name, shape, input_file): +def generate_data(name, shape, input_file, tensor_range): np.random.seed() - data = np.random.random(shape) * 2 - 1 + data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \ + + tensor_range[0] input_file_name = common.formatted_file_name(input_file, name) print 'Generate input file: ', input_file_name data.astype(np.float32).tofile(input_file_name) -def generate_input_data(input_file, input_node, input_shape): +def generate_input_data(input_file, input_node, input_shape, input_ranges): input_names = [name for name in input_node.split(',')] input_shapes = [shape for shape in input_shape.split(':')] + if input_ranges: + input_ranges = [r for r in input_ranges.split(':')] + else: + input_ranges = None assert len(input_names) == len(input_shapes) for i in range(len(input_names)): shape = [int(x) for x in input_shapes[i].split(',')] - generate_data(input_names[i], shape, input_file) + if input_ranges: + input_range = [float(x) for x in input_ranges[i].split(',')] + else: + input_range = [-1, 1] + generate_data(input_names[i], shape, input_file, input_range) print "Generate input file done." @@ -55,10 +64,13 @@ def parse_args(): "--input_node", type=str, default="input_node", help="input node") parser.add_argument( "--input_shape", type=str, default="1,64,64,3", help="input shape.") + parser.add_argument( + "--input_ranges", type=str, default="-1,1", help="input range.") return parser.parse_known_args() if __name__ == '__main__': FLAGS, unparsed = parse_args() - generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape) + generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape, + FLAGS.input_ranges) -- GitLab