提交 3d26fe4d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5220 onnx net and fix bug

Merge pull request !5220 from 徐安越/master
...@@ -457,6 +457,7 @@ table Min { ...@@ -457,6 +457,7 @@ table Min {
table Slice { table Slice {
format: Format = 0; format: Format = 0;
axes: [int];
begin: [int]; begin: [int];
size: [int]; size: [int];
} }
......
...@@ -65,10 +65,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor ...@@ -65,10 +65,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} }
auto indices_shape = indices->shape(); auto indices_shape = indices->shape();
int indices_rank = indices_shape.size(); int indices_rank = indices_shape.size();
if (indices_rank < batch_dims + 1) {
MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1";
return RET_ERROR;
}
if (batch_dims != 0) { if (batch_dims != 0) {
MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support"; MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support";
return RET_ERROR; return RET_ERROR;
......
...@@ -38,6 +38,7 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: ...@@ -38,6 +38,7 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto in_tensor = inputs_.front(); auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front(); auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeInt32); out_tensor->set_data_type(kNumberTypeInt32);
out_tensor->SetFormat(schema::Format_NHWC);
if (!GetInferFlag()) { if (!GetInferFlag()) {
return RET_OK; return RET_OK;
} }
......
...@@ -29,6 +29,7 @@ constexpr int kSliceOutputNum = 1; ...@@ -29,6 +29,7 @@ constexpr int kSliceOutputNum = 1;
int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; } int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; }
std::vector<int> Slice::GetBegin() const { return this->primitive_->value.AsSlice()->begin; } std::vector<int> Slice::GetBegin() const { return this->primitive_->value.AsSlice()->begin; }
std::vector<int> Slice::GetSize() const { return this->primitive_->value.AsSlice()->size; } std::vector<int> Slice::GetSize() const { return this->primitive_->value.AsSlice()->size; }
std::vector<int> Slice::GetAxes() const { return this->primitive_->value.AsSlice()->axes; }
void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; } void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; }
void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; } void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; }
...@@ -45,9 +46,14 @@ std::vector<int> Slice::GetSize() const { ...@@ -45,9 +46,14 @@ std::vector<int> Slice::GetSize() const {
auto fb_vector = this->primitive_->value_as_Slice()->size(); auto fb_vector = this->primitive_->value_as_Slice()->size();
return std::vector<int>(fb_vector->begin(), fb_vector->end()); return std::vector<int>(fb_vector->begin(), fb_vector->end());
} }
std::vector<int> Slice::GetAxes() const {
auto fb_vector = this->primitive_->value_as_Slice()->axes();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
#endif #endif
std::vector<int> Slice::GetPostProcessBegin() const { return this->begin; }
std::vector<int> Slice::GetPostProcessSize() const { return this->size; }
int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) { int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr); MS_ASSERT(this->primitive_ != nullptr);
if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) {
...@@ -61,30 +67,37 @@ int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<li ...@@ -61,30 +67,37 @@ int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<li
return RET_OK; return RET_OK;
} }
auto input_shape = input->shape(); auto input_shape = input->shape();
std::vector<int32_t> slice_begin(GetBegin().begin(), GetBegin().end()); std::vector<int32_t> slice_begin(GetBegin());
std::vector<int32_t> slice_size(GetSize().begin(), GetSize().end()); std::vector<int32_t> slice_size(GetSize());
std::vector<int32_t> slice_axes(GetAxes());
std::vector<int32_t> output_shape(input_shape.size()); std::vector<int32_t> output_shape(input_shape.size());
begin.assign(input_shape.size(), 0);
size.assign(input_shape.size(), -1);
for (size_t i = 0; i < slice_axes.size(); ++i) {
begin[slice_axes[i]] = slice_begin[i];
size[slice_axes[i]] = slice_size[i];
}
for (size_t i = 0; i < input_shape.size(); ++i) { for (size_t i = 0; i < input_shape.size(); ++i) {
if (slice_size[i] < 0 && slice_size[i] != -1) { if (size[i] < 0 && size[i] != -1) {
MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i]; MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << size[i];
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (slice_begin[i] < 0) { if (begin[i] < 0) {
MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0"; MS_LOG(ERROR) << "Invalid begin input " << begin[i] << " which should be >= 0";
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (input_shape[i] <= slice_begin[i]) { if (input_shape[i] <= begin[i]) {
MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i] MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << begin[i]
<< " which should be <= " << input_shape[i]; << " which should be <= " << input_shape[i];
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (slice_size[i] > (input_shape[i] - slice_begin[i])) { if (size[i] > (input_shape[i] - begin[i])) {
MS_LOG(ERROR) << "Invalid size input " << slice_size[i] MS_LOG(ERROR) << "Invalid size input " << size[i]
<< " which should be <= " << input_shape[i] - slice_begin[i]; << " which should be <= " << input_shape[i] - begin[i];
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i]; output_shape[i] = size[i] < 0 ? input_shape[i] - begin[i] : size[i];
} }
outputs[0]->set_shape(output_shape); outputs[0]->set_shape(output_shape);
......
...@@ -41,6 +41,14 @@ class Slice : public PrimitiveC { ...@@ -41,6 +41,14 @@ class Slice : public PrimitiveC {
int GetFormat() const; int GetFormat() const;
std::vector<int> GetBegin() const; std::vector<int> GetBegin() const;
std::vector<int> GetSize() const; std::vector<int> GetSize() const;
std::vector<int> GetAxes() const;
// due to difference between tflite and onnx, when inferring shape, construct new parameters of begin and size.
// when running graph, we need to obtain new begins and sizes using the two function as below.
std::vector<int> GetPostProcessBegin() const;
std::vector<int> GetPostProcessSize() const;
protected:
std::vector<int> begin = {0};
std::vector<int> size = {-1};
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
......
...@@ -1010,8 +1010,8 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive ...@@ -1010,8 +1010,8 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive
memset(slice_param, 0, sizeof(SliceParameter)); memset(slice_param, 0, sizeof(SliceParameter));
auto param = reinterpret_cast<mindspore::lite::Slice *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); auto param = reinterpret_cast<mindspore::lite::Slice *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
slice_param->op_parameter_.type_ = primitive->Type(); slice_param->op_parameter_.type_ = primitive->Type();
auto param_begin = param->GetBegin(); auto param_begin = param->GetPostProcessBegin();
auto param_size = param->GetSize(); auto param_size = param->GetPostProcessSize();
if (param_begin.size() != param_size.size()) { if (param_begin.size() != param_size.size()) {
free(slice_param); free(slice_param);
return nullptr; return nullptr;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "nnacl/fp32/slice.h" #include "nnacl/fp32/slice.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/ops/slice.h"
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
...@@ -40,7 +41,15 @@ int SliceLaunch(void *cdata, int task_id) { ...@@ -40,7 +41,15 @@ int SliceLaunch(void *cdata, int task_id) {
} // namespace } // namespace
int SliceCPUKernel::ReSize() { int SliceCPUKernel::ReSize() {
auto *param = reinterpret_cast<SliceParameter *>(op_parameter_); auto primitive_slice = reinterpret_cast<const mindspore::lite::Slice *>(primitive_);
auto begin = primitive_slice->GetPostProcessBegin();
auto size = primitive_slice->GetPostProcessSize();
auto param = reinterpret_cast<SliceParameter *>(op_parameter_);
param->param_length_ = in_tensors_[0]->shape().size();
for (int i = 0; i < param->param_length_; ++i) {
param->begin_[i] = begin[i];
param->size_[i] = size[i];
}
auto input_shape = in_tensors_[0]->shape(); auto input_shape = in_tensors_[0]->shape();
if (static_cast<int>(input_shape.size()) != param->param_length_) { if (static_cast<int>(input_shape.size()) != param->param_length_) {
MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size " MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size "
......
...@@ -24,8 +24,8 @@ namespace lite { ...@@ -24,8 +24,8 @@ namespace lite {
namespace converter { namespace converter {
Flags::Flags() { Flags::Flags() {
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", ""); AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
AddFlag(&Flags::modelFile, "modelFile", "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir", AddFlag(&Flags::modelFile, "modelFile",
""); "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::weightFile, "weightFile", AddFlag(&Flags::weightFile, "weightFile",
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
...@@ -41,6 +41,10 @@ Flags::Flags() { ...@@ -41,6 +41,10 @@ Flags::Flags() {
} }
int Flags::Init(int argc, const char **argv) { int Flags::Init(int argc, const char **argv) {
if (argc == 1) {
std::cout << this->Usage() << std::endl;
return RET_SUCCESS_EXIT;
}
Option<std::string> err = this->ParseFlags(argc, argv); Option<std::string> err = this->ParseFlags(argc, argv);
if (err.IsSome()) { if (err.IsSome()) {
......
...@@ -121,7 +121,8 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt ...@@ -121,7 +121,8 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
return RET_ERROR; return RET_ERROR;
} }
// NHWC // NHWC
scaleParam->axis = -1; int shape_size = graph->allTensors.at(addBiasIndex)->dims.size();
scaleParam->axis = 0 - shape_size;
mulNode->primitive->value.value = scaleParam.release(); mulNode->primitive->value.value = scaleParam.release();
mulNode->inputIndex.push_back(addBiasIndex); mulNode->inputIndex.push_back(addBiasIndex);
if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) { if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) {
......
...@@ -38,22 +38,38 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No ...@@ -38,22 +38,38 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::vector<int> axes;
std::vector<int> starts;
std::vector<int> ends;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "starts") { if (attribute_name == "starts") {
const int size = onnx_node_attr.ints_size(); const int num = onnx_node_attr.ints_size();
MS_LOG(INFO) << "SLICE starts size " << size; starts.clear();
for (int i = 0; i < size; ++i) { for (int i = 0; i < num; ++i) {
attr->begin.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i))); starts.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
}
} else if (attribute_name == "axes") {
const int num = onnx_node_attr.ints_size();
axes.clear();
for (int i = 0; i < num; ++i) {
axes.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
} }
} else if (attribute_name == "ends") { } else if (attribute_name == "ends") {
const int size = onnx_node_attr.ints_size(); const int num = onnx_node_attr.ints_size();
for (int i = 0; i < size; ++i) { ends.clear();
attr->size.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i))); for (int i = 0; i < num; ++i) {
ends.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
} }
} }
} }
std::vector<int> sizes(starts.size(), -1);
for (size_t i = 0; i < starts.size(); ++i) {
sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]);
}
attr->axes = axes;
attr->begin = starts;
attr->size = sizes;
op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
......
...@@ -55,7 +55,12 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite ...@@ -55,7 +55,12 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
MS_LOG(ERROR) << "get slice -> size failed"; MS_LOG(ERROR) << "get slice -> size failed";
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> axes;
axes.clear();
for (size_t i = 0; i < attr->begin.size(); ++i) {
axes.push_back(i);
}
attr->axes = axes;
op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
......
...@@ -72,8 +72,7 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) { ...@@ -72,8 +72,7 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
} }
const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
auto parameter = func_graph->add_parameter(); auto parameter = func_graph->add_parameter();
std::vector<int> shape; std::vector<int> shape(tensor->shape());
std::copy(tensor->shape().begin(), tensor->shape().end(), std::back_inserter(shape));
auto type_id = static_cast<TypeId>(tensor->data_type()); auto type_id = static_cast<TypeId>(tensor->data_type());
auto type_ptr = TypeIdToType(type_id); auto type_ptr = TypeIdToType(type_id);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
...@@ -160,6 +159,15 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An ...@@ -160,6 +159,15 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
MS_LOG(ERROR) << "lite_primitive is nullptr"; MS_LOG(ERROR) << "lite_primitive is nullptr";
return nullptr; return nullptr;
} }
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.
for (size_t j = 0; j < input_tensors.size(); ++j) {
input_tensors[j]->SetFormat(schema::Format_NHWC);
if (input_tensors[j]->shape().size() == 4) {
MS_LOG(WARNING) << "init input_tensor format to nhwc";
}
}
lite_primitive->InferShape(input_tensors, output_tensors); lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get()); auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get());
if (lite_kernel == nullptr) { if (lite_kernel == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册