From 020455ced80b33c6c47b50b45a188a8d30f84654 Mon Sep 17 00:00:00 2001 From: luxuhui Date: Tue, 18 Feb 2020 22:26:03 +0800 Subject: [PATCH] fix bug on reshape op & tar command & merge_duplicate_nodes issue:593 Signed-off-by: Luxuhui --- mace/ops/reduce.cc | 2 ++ mace/ops/reshape.cc | 20 +++++++++++-------- .../python/transform/tensorflow_converter.py | 1 - tools/python/transform/transformer.py | 3 ++- tools/sh_commands.py | 10 +--------- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/mace/ops/reduce.cc b/mace/ops/reduce.cc index cd78320c..7c34db3e 100644 --- a/mace/ops/reduce.cc +++ b/mace/ops/reduce.cc @@ -1035,6 +1035,8 @@ class ReduceOp : public ReduceOpBase { void RegisterReduce(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, DeviceType::CPU, float); + MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, + DeviceType::CPU, int); #ifdef MACE_ENABLE_QUANTIZE MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, DeviceType::CPU, uint8_t); diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 8e469b26..94561720 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -74,7 +74,7 @@ template class ReshapeOp : public Operation { public: explicit ReshapeOp(OpConstructContext *context) - : Operation(context), + : Operation(context), dim_(Operation::GetRepeatedArgs("dim")), has_df_(Operation::GetOptionalArg("has_data_format", 0)) {} MaceStatus Run(OpContext *context) override { @@ -85,18 +85,21 @@ class ReshapeOp : public Operation { const int32_t *shape_data = shape->data(); const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0); std::vector out_shape; - MACE_RETURN_IF_ERROR( - GetOutputShape(input, shape_data, num_dims, &out_shape)); // NHWC -> NCHW - if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 && - shape->is_weight()) { + std::vector trans_shape_data(shape_data, + shape_data + shape->size()); + if (has_df_ && D == DeviceType::CPU && shape->dim_size() == 4 && + out_shape.size() == 4 && dim_.size() == 4) { std::vector dst_dims = {0, 3, 1, 2}; - std::vector trans_shape = - TransposeShape(out_shape, dst_dims); - out_shape = trans_shape; + std::vector tmp_shape = + TransposeShape(trans_shape_data, dst_dims); + trans_shape_data = tmp_shape; } + MACE_RETURN_IF_ERROR( + GetOutputShape(input, trans_shape_data.data(), num_dims, &out_shape)); + Tensor *output = this->Output(OUTPUT); output->ReuseTensorBuffer(*input); output->Reshape(out_shape); @@ -105,6 +108,7 @@ class ReshapeOp : public Operation { } private: + std::vector dim_; bool has_df_; private: diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index ed516bca..73a62dd7 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -142,7 +142,6 @@ TFTransformGraphOptions = [ 'fold_old_batch_norms', 'remove_control_dependencies', 'strip_unused_nodes', - 'merge_duplicate_nodes', 'sort_by_execution_order' ] diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 69411e41..9b26d487 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1395,7 +1395,8 @@ class Transformer(base_converter.ConverterInterface): if op.type == MaceOp.Reshape: input_op = self._producer[op.input[0]] out_dims_len = len(op.output_shape[0].dims) - if len(input_op.output_shape[0].dims) != 4 \ + if len(input_op.output_shape) != 1 or \ + len(input_op.output_shape[0].dims) != 4 \ or (out_dims_len != 4 and out_dims_len != 2): print("In this model, reshape is not transposable op.") return False diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 95d93c5d..5c1c8569 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -774,15 +774,7 @@ def packaging_lib(libmace_output_dir, project_name): six.print_("Start packaging '%s' libs into %s" % (project_name, tar_package_path)) which_sys = platform.system() - if which_sys == "Linux": - sh.tar( - "cvzf", - "%s" % tar_package_path, - glob.glob("%s/*" % project_dir), - "--exclude", - "%s/_tmp" % project_dir, - _fg=True) - elif which_sys == "Darwin": + if which_sys == "Linux" or which_sys == "Darwin": sh.tar( "--exclude", "%s/_tmp" % project_dir, -- GitLab