“ca8b7a2760be58670ad9388fddc99a1f4bb0d88a”上不存在“projects/yangchaoy259189888”
提交 020455ce 编写于 作者: L luxuhui

fix bug on reshape op & tar command & merge_duplicate_nodes

issue:593
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 be149a74
......@@ -1035,6 +1035,8 @@ class ReduceOp<DeviceType::GPU, float> : public ReduceOpBase {
void RegisterReduce(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, int);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, uint8_t);
......
......@@ -74,7 +74,7 @@ template <DeviceType D, class T>
class ReshapeOp : public Operation {
public:
explicit ReshapeOp(OpConstructContext *context)
: Operation(context),
: Operation(context), dim_(Operation::GetRepeatedArgs<int>("dim")),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override {
......@@ -85,18 +85,21 @@ class ReshapeOp : public Operation {
const int32_t *shape_data = shape->data<int32_t>();
const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0);
std::vector<index_t> out_shape;
MACE_RETURN_IF_ERROR(
GetOutputShape(input, shape_data, num_dims, &out_shape));
// NHWC -> NCHW
if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 &&
shape->is_weight()) {
std::vector<int32_t> trans_shape_data(shape_data,
shape_data + shape->size());
if (has_df_ && D == DeviceType::CPU && shape->dim_size() == 4 &&
out_shape.size() == 4 && dim_.size() == 4) {
std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> trans_shape =
TransposeShape<index_t, index_t>(out_shape, dst_dims);
out_shape = trans_shape;
std::vector<int32_t> tmp_shape =
TransposeShape<int32_t , int32_t>(trans_shape_data, dst_dims);
trans_shape_data = tmp_shape;
}
MACE_RETURN_IF_ERROR(
GetOutputShape(input, trans_shape_data.data(), num_dims, &out_shape));
Tensor *output = this->Output(OUTPUT);
output->ReuseTensorBuffer(*input);
output->Reshape(out_shape);
......@@ -105,6 +108,7 @@ class ReshapeOp : public Operation {
}
private:
std::vector<int> dim_;
bool has_df_;
private:
......
......@@ -142,7 +142,6 @@ TFTransformGraphOptions = [
'fold_old_batch_norms',
'remove_control_dependencies',
'strip_unused_nodes',
'merge_duplicate_nodes',
'sort_by_execution_order'
]
......
......@@ -1395,7 +1395,8 @@ class Transformer(base_converter.ConverterInterface):
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
out_dims_len = len(op.output_shape[0].dims)
if len(input_op.output_shape[0].dims) != 4 \
if len(input_op.output_shape) != 1 or \
len(input_op.output_shape[0].dims) != 4 \
or (out_dims_len != 4 and out_dims_len != 2):
print("In this model, reshape is not transposable op.")
return False
......
......@@ -774,15 +774,7 @@ def packaging_lib(libmace_output_dir, project_name):
six.print_("Start packaging '%s' libs into %s" % (project_name,
tar_package_path))
which_sys = platform.system()
if which_sys == "Linux":
sh.tar(
"cvzf",
"%s" % tar_package_path,
glob.glob("%s/*" % project_dir),
"--exclude",
"%s/_tmp" % project_dir,
_fg=True)
elif which_sys == "Darwin":
if which_sys == "Linux" or which_sys == "Darwin":
sh.tar(
"--exclude",
"%s/_tmp" % project_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册