diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 697727862203687921f74b3b507a828083108fb4..ebb82b7f2e7e7671730d09c138a29c5b150ebd04 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -457,6 +457,7 @@ table Min { table Slice { format: Format = 0; + axes: [int]; begin: [int]; size: [int]; } diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 837d4a5840429e9a74262c4daebe0e47eca710e0..a899bbb0a61b52e0f05644761aac53e6e30b1154 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -65,10 +65,6 @@ int Gather::InferShape(std::vector inputs_, std::vectorshape(); int indices_rank = indices_shape.size(); - if (indices_rank < batch_dims + 1) { - MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1"; - return RET_ERROR; - } if (batch_dims != 0) { MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support"; return RET_ERROR; diff --git a/mindspore/lite/src/ops/shape.cc b/mindspore/lite/src/ops/shape.cc index 2680cf03e3a43e41714b4238090cbd6ea49225a4..a6b8098b2c54d49e3dc0b3dea85eb31a6ac79bed 100644 --- a/mindspore/lite/src/ops/shape.cc +++ b/mindspore/lite/src/ops/shape.cc @@ -38,6 +38,7 @@ int Shape::InferShape(std::vector inputs_, std::vectorset_data_type(kNumberTypeInt32); + out_tensor->SetFormat(schema::Format_NHWC); if (!GetInferFlag()) { return RET_OK; } diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 38135a7d593a659b87015b179fba7dd9ffe219bb..3f90bcd147946a2705ceeee01609a279a3b9d2f7 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -29,6 +29,7 @@ constexpr int kSliceOutputNum = 1; int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; } std::vector Slice::GetBegin() const { return this->primitive_->value.AsSlice()->begin; } std::vector Slice::GetSize() const { return this->primitive_->value.AsSlice()->size; } +std::vector Slice::GetAxes() const { return this->primitive_->value.AsSlice()->axes; } void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; } void Slice::SetBegin(const std::vector &begin) { this->primitive_->value.AsSlice()->begin = begin; } @@ -45,9 +46,14 @@ std::vector Slice::GetSize() const { auto fb_vector = this->primitive_->value_as_Slice()->size(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +std::vector Slice::GetAxes() const { + auto fb_vector = this->primitive_->value_as_Slice()->axes(); + return std::vector(fb_vector->begin(), fb_vector->end()); +} #endif +std::vector Slice::GetPostProcessBegin() const { return this->begin; } +std::vector Slice::GetPostProcessSize() const { return this->size; } int Slice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { @@ -61,30 +67,37 @@ int Slice::InferShape(std::vector inputs, std::vector
  • shape(); - std::vector slice_begin(GetBegin().begin(), GetBegin().end()); - std::vector slice_size(GetSize().begin(), GetSize().end()); + std::vector slice_begin(GetBegin()); + std::vector slice_size(GetSize()); + std::vector slice_axes(GetAxes()); std::vector output_shape(input_shape.size()); + begin.assign(input_shape.size(), 0); + size.assign(input_shape.size(), -1); + for (size_t i = 0; i < slice_axes.size(); ++i) { + begin[slice_axes[i]] = slice_begin[i]; + size[slice_axes[i]] = slice_size[i]; + } for (size_t i = 0; i < input_shape.size(); ++i) { - if (slice_size[i] < 0 && slice_size[i] != -1) { - MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i]; + if (size[i] < 0 && size[i] != -1) { + MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << size[i]; return RET_PARAM_INVALID; } - if (slice_begin[i] < 0) { - MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0"; + if (begin[i] < 0) { + MS_LOG(ERROR) << "Invalid begin input " << begin[i] << " which should be >= 0"; return RET_PARAM_INVALID; } - if (input_shape[i] <= slice_begin[i]) { - MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i] + if (input_shape[i] <= begin[i]) { + MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << begin[i] << " which should be <= " << input_shape[i]; return RET_PARAM_INVALID; } - if (slice_size[i] > (input_shape[i] - slice_begin[i])) { - MS_LOG(ERROR) << "Invalid size input " << slice_size[i] - << " which should be <= " << input_shape[i] - slice_begin[i]; + if (size[i] > (input_shape[i] - begin[i])) { + MS_LOG(ERROR) << "Invalid size input " << size[i] + << " which should be <= " << input_shape[i] - begin[i]; return RET_PARAM_INVALID; } - output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i]; + output_shape[i] = size[i] < 0 ? input_shape[i] - begin[i] : size[i]; } outputs[0]->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/slice.h b/mindspore/lite/src/ops/slice.h index 589135338e4d5dad9c9e339e28a824b749e9f890..1bf67ae798875213de7474d47481be79f2479fa4 100644 --- a/mindspore/lite/src/ops/slice.h +++ b/mindspore/lite/src/ops/slice.h @@ -41,6 +41,14 @@ class Slice : public PrimitiveC { int GetFormat() const; std::vector GetBegin() const; std::vector GetSize() const; + std::vector GetAxes() const; + // due to difference between tflite and onnx, when inferring shape, construct new parameters of begin and size. + // when running graph, we need to obtain new begins and sizes using the two function as below. + std::vector GetPostProcessBegin() const; + std::vector GetPostProcessSize() const; + protected: + std::vector begin = {0}; + std::vector size = {-1}; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index ab1ab98a6350bd276bd72c8ac33f66b10a29a321..31c43fe55ce7ef0f65a493f2fe7c6de1275d6448 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1010,8 +1010,8 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive memset(slice_param, 0, sizeof(SliceParameter)); auto param = reinterpret_cast(const_cast(primitive)); slice_param->op_parameter_.type_ = primitive->Type(); - auto param_begin = param->GetBegin(); - auto param_size = param->GetSize(); + auto param_begin = param->GetPostProcessBegin(); + auto param_size = param->GetPostProcessSize(); if (param_begin.size() != param_size.size()) { free(slice_param); return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc index dccbd7a40aecabe78d8a674cd597ac1e8f8b8ca9..7e5fb0b28f7582c830df31a9be98ec678014154e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc @@ -20,6 +20,7 @@ #include "nnacl/fp32/slice.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" +#include "src/ops/slice.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; @@ -40,7 +41,15 @@ int SliceLaunch(void *cdata, int task_id) { } // namespace int SliceCPUKernel::ReSize() { - auto *param = reinterpret_cast(op_parameter_); + auto primitive_slice = reinterpret_cast(primitive_); + auto begin = primitive_slice->GetPostProcessBegin(); + auto size = primitive_slice->GetPostProcessSize(); + auto param = reinterpret_cast(op_parameter_); + param->param_length_ = in_tensors_[0]->shape().size(); + for (int i = 0; i < param->param_length_; ++i) { + param->begin_[i] = begin[i]; + param->size_[i] = size[i]; + } auto input_shape = in_tensors_[0]->shape(); if (static_cast(input_shape.size()) != param->param_length_) { MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size " diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 4059ef12a4d80f7c2dced701a4bcb99a49561d3c..b0f2ec605f10a8d5201f8973704060049705ee6d 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -24,8 +24,8 @@ namespace lite { namespace converter { Flags::Flags() { AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", ""); - AddFlag(&Flags::modelFile, "modelFile", "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir", - ""); + AddFlag(&Flags::modelFile, "modelFile", + "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", ""); AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); AddFlag(&Flags::weightFile, "weightFile", "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); @@ -41,6 +41,10 @@ Flags::Flags() { } int Flags::Init(int argc, const char **argv) { + if (argc == 1) { + std::cout << this->Usage() << std::endl; + return RET_SUCCESS_EXIT; + } Option err = this->ParseFlags(argc, argv); if (err.IsSome()) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc index fdc6f530d6e47810843db60bcffad4e9949d5e1b..a2193866b594780d9819bcd761ba6cfdf84d9fed 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc @@ -121,7 +121,8 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt return RET_ERROR; } // NHWC - scaleParam->axis = -1; + int shape_size = graph->allTensors.at(addBiasIndex)->dims.size(); + scaleParam->axis = 0 - shape_size; mulNode->primitive->value.value = scaleParam.release(); mulNode->inputIndex.push_back(addBiasIndex); if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 91f7fd2e4aeb9105c3085131fcb7c2dd00cc0281..d89c5dcb54385d08d9a7a0507a1cd560f8e045f5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -38,22 +38,38 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No return RET_NULL_PTR; } + std::vector axes; + std::vector starts; + std::vector ends; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "starts") { - const int size = onnx_node_attr.ints_size(); - MS_LOG(INFO) << "SLICE starts size " << size; - for (int i = 0; i < size; ++i) { - attr->begin.emplace_back(static_cast(onnx_node_attr.ints(i))); + const int num = onnx_node_attr.ints_size(); + starts.clear(); + for (int i = 0; i < num; ++i) { + starts.push_back(static_cast(onnx_node_attr.ints()[i])); + } + } else if (attribute_name == "axes") { + const int num = onnx_node_attr.ints_size(); + axes.clear(); + for (int i = 0; i < num; ++i) { + axes.push_back(static_cast(onnx_node_attr.ints()[i])); } } else if (attribute_name == "ends") { - const int size = onnx_node_attr.ints_size(); - for (int i = 0; i < size; ++i) { - attr->size.emplace_back(static_cast(onnx_node_attr.ints(i))); + const int num = onnx_node_attr.ints_size(); + ends.clear(); + for (int i = 0; i < num; ++i) { + ends.push_back(static_cast(onnx_node_attr.ints()[i])); } } } - + std::vector sizes(starts.size(), -1); + for (size_t i = 0; i < starts.size(); ++i) { + sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]); + } + attr->axes = axes; + attr->begin = starts; + attr->size = sizes; op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index 12e1afaea6296692c737abc31bef2da151f26ff8..571f6ad4d136fe4d0c0a3fbe24cd327a2d9f01bb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -55,7 +55,12 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite MS_LOG(ERROR) << "get slice -> size failed"; return RET_ERROR; } - + std::vector axes; + axes.clear(); + for (size_t i = 0; i < attr->begin.size(); ++i) { + axes.push_back(i); + } + attr->axes = axes; op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 0915b2afcd5970dee88c21129bb2436c27a00e63..55537d41b82a8484be8ba294febac7bf115dcc1f 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -72,8 +72,7 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { } const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { auto parameter = func_graph->add_parameter(); - std::vector shape; - std::copy(tensor->shape().begin(), tensor->shape().end(), std::back_inserter(shape)); + std::vector shape(tensor->shape()); auto type_id = static_cast(tensor->data_type()); auto type_ptr = TypeIdToType(type_id); auto abstract_tensor = std::make_shared(type_ptr, shape); @@ -160,6 +159,15 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An MS_LOG(ERROR) << "lite_primitive is nullptr"; return nullptr; } + // here, input_tensor's format need to be transposed nhwc according to fmkType, + // but for the time being, we only transpose the tensor with 0/1/2/3D. + // Others should be added in future. + for (size_t j = 0; j < input_tensors.size(); ++j) { + input_tensors[j]->SetFormat(schema::Format_NHWC); + if (input_tensors[j]->shape().size() == 4) { + MS_LOG(WARNING) << "init input_tensor format to nhwc"; + } + } lite_primitive->InferShape(input_tensors, output_tensors); auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get()); if (lite_kernel == nullptr) {