diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 33850faf57fe2332419c5157a4fc89205125ddfa..e578e69aec0f314e854c70fabd411c61bc3f81bc 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -23,16 +23,12 @@ template class ReshapeOp : public Operation { public: explicit ReshapeOp(OpConstructContext *context) - : Operation(context) {} + : Operation(context), + has_df_(Operation::GetOptionalArg("has_data_format", 0)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(INPUT); - const std::vector &input_shape = input->shape(); - int axis = Operation::GetOptionalArg("reshape_axis", 0); - int num_axes = Operation::GetOptionalArg("num_axes", -1); - MACE_CHECK(axis == 0 && num_axes == -1, - "Only support axis = 0 and num_axes = -1"); const Tensor *shape = this->Input(SHAPE); const index_t num_dims = shape->dim_size() == 0 ? 0 : shape->dim(0); Tensor::MappingGuard shape_guard(shape); @@ -40,20 +36,16 @@ class ReshapeOp : public Operation { int unknown_idx = -1; index_t product = 1; - std::vector out_shape; + std::vector out_shape(num_dims); index_t n = 0; for (int i = 0; i < num_dims; ++i) { if (shape_data[i] == -1) { MACE_CHECK(unknown_idx == -1, "Only one input size may be -1"); unknown_idx = i; - out_shape.push_back(1); - } else if (shape_data[i] == 0) { - MACE_CHECK(shape_data[i] == 0, "Shape should be 0"); - out_shape.push_back(input_shape[i]); - product *= input_shape[i]; + out_shape[i] = 1; } else { - MACE_CHECK(shape_data[i] > 0, "Shape must be non-negative: ", + MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", shape_data[i]); if (shape_data[i] == 0) { MACE_CHECK(i < input->dim_size(), @@ -62,7 +54,7 @@ class ReshapeOp : public Operation { } else { n = shape_data[i]; } - out_shape.push_back(n); + out_shape[i] = n; product *= n; } } @@ -77,14 +69,13 @@ class ReshapeOp : public Operation { } Tensor *output = this->Output(OUTPUT); // NHWC -> NCHW - auto has_df = Operation::GetOptionalArg( - "has_data_format", 0); - if (has_df && D == DeviceType::CPU + + if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 && shape->is_weight()) { std::vector dst_dims = {0, 3, 1, 2}; - std::vector out_shape_gpu = TransposeShape( + std::vector trans_shape = TransposeShape( out_shape, dst_dims); - out_shape = out_shape_gpu; + out_shape = trans_shape; } output->ReuseTensorBuffer(*input); @@ -93,6 +84,9 @@ class ReshapeOp : public Operation { return MaceStatus::MACE_SUCCESS; } + private: + bool has_df_; + private: MACE_OP_INPUT_TAGS(INPUT, SHAPE); MACE_OP_OUTPUT_TAGS(OUTPUT);