提交 020455ce 编写于 作者: L luxuhui

fix bug on reshape op & tar command & merge_duplicate_nodes

issue:593
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 be149a74
...@@ -1035,6 +1035,8 @@ class ReduceOp<DeviceType::GPU, float> : public ReduceOpBase { ...@@ -1035,6 +1035,8 @@ class ReduceOp<DeviceType::GPU, float> : public ReduceOpBase {
void RegisterReduce(OpRegistryBase *op_registry) { void RegisterReduce(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, int);
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, uint8_t); DeviceType::CPU, uint8_t);
......
...@@ -74,7 +74,7 @@ template <DeviceType D, class T> ...@@ -74,7 +74,7 @@ template <DeviceType D, class T>
class ReshapeOp : public Operation { class ReshapeOp : public Operation {
public: public:
explicit ReshapeOp(OpConstructContext *context) explicit ReshapeOp(OpConstructContext *context)
: Operation(context), : Operation(context), dim_(Operation::GetRepeatedArgs<int>("dim")),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {} has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
...@@ -85,18 +85,21 @@ class ReshapeOp : public Operation { ...@@ -85,18 +85,21 @@ class ReshapeOp : public Operation {
const int32_t *shape_data = shape->data<int32_t>(); const int32_t *shape_data = shape->data<int32_t>();
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0); const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
std::vector<index_t> out_shape; std::vector<index_t> out_shape;
MACE_RETURN_IF_ERROR(
GetOutputShape(input, shape_data, num_dims, &out_shape));
// NHWC -> NCHW // NHWC -> NCHW
if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 && std::vector<int32_t> trans_shape_data(shape_data,
shape->is_weight()) { shape_data + shape->size());
if (has_df_ && D == DeviceType::CPU && shape->dim_size() == 4 &&
out_shape.size() == 4 && dim_.size() == 4) {
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> trans_shape = std::vector<int32_t> tmp_shape =
TransposeShape<index_t, index_t>(out_shape, dst_dims); TransposeShape<int32_t , int32_t>(trans_shape_data, dst_dims);
out_shape = trans_shape; trans_shape_data = tmp_shape;
} }
MACE_RETURN_IF_ERROR(
GetOutputShape(input, trans_shape_data.data(), num_dims, &out_shape));
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
output->ReuseTensorBuffer(*input); output->ReuseTensorBuffer(*input);
output->Reshape(out_shape); output->Reshape(out_shape);
...@@ -105,6 +108,7 @@ class ReshapeOp : public Operation { ...@@ -105,6 +108,7 @@ class ReshapeOp : public Operation {
} }
private: private:
std::vector<int> dim_;
bool has_df_; bool has_df_;
private: private:
......
...@@ -142,7 +142,6 @@ TFTransformGraphOptions = [ ...@@ -142,7 +142,6 @@ TFTransformGraphOptions = [
'fold_old_batch_norms', 'fold_old_batch_norms',
'remove_control_dependencies', 'remove_control_dependencies',
'strip_unused_nodes', 'strip_unused_nodes',
'merge_duplicate_nodes',
'sort_by_execution_order' 'sort_by_execution_order'
] ]
......
...@@ -1395,7 +1395,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1395,7 +1395,8 @@ class Transformer(base_converter.ConverterInterface):
if op.type == MaceOp.Reshape: if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]] input_op = self._producer[op.input[0]]
out_dims_len = len(op.output_shape[0].dims) out_dims_len = len(op.output_shape[0].dims)
if len(input_op.output_shape[0].dims) != 4 \ if len(input_op.output_shape) != 1 or \
len(input_op.output_shape[0].dims) != 4 \
or (out_dims_len != 4 and out_dims_len != 2): or (out_dims_len != 4 and out_dims_len != 2):
print("In this model, reshape is not transposable op.") print("In this model, reshape is not transposable op.")
return False return False
......
...@@ -774,15 +774,7 @@ def packaging_lib(libmace_output_dir, project_name): ...@@ -774,15 +774,7 @@ def packaging_lib(libmace_output_dir, project_name):
six.print_("Start packaging '%s' libs into %s" % (project_name, six.print_("Start packaging '%s' libs into %s" % (project_name,
tar_package_path)) tar_package_path))
which_sys = platform.system() which_sys = platform.system()
if which_sys == "Linux": if which_sys == "Linux" or which_sys == "Darwin":
sh.tar(
"cvzf",
"%s" % tar_package_path,
glob.glob("%s/*" % project_dir),
"--exclude",
"%s/_tmp" % project_dir,
_fg=True)
elif which_sys == "Darwin":
sh.tar( sh.tar(
"--exclude", "--exclude",
"%s/_tmp" % project_dir, "%s/_tmp" % project_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册