提交 49fd9fa9 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4372 Modify the method for getting output index of metagraph.

Merge pull request !4372 from wangshaocong/lite
......@@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i) {
if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
reduce_axis = true;
break;
}
......
......@@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() {
return RET_ERROR;
}
for (auto i = 0; i < num_axes_; i++) {
if (axes_[i] < -static_cast<int>(input_rank) || static_cast<size_t>(axes_[i]) >= input_rank) {
if (axes_[i] < -static_cast<int>(input_rank) || axes_[i] >= static_cast<int>(input_rank)) {
MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in ["
<< -static_cast<int>(input_rank) << ", " << input_rank - 1 << "].";
return RET_ERROR;
......
......@@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
}
}
void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef) {
auto opGraph = OpGraphT::Build(subGraphDef);
auto graphInputs = tensorCache.GetGraphInputs();
auto graphOutputs = opGraph->GetOutputNode();
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());
for (const auto &output : graphOutputs) {
auto op = opMap[output->ID()];
for (auto outputIndex : op->outputIndex) {
subGraphDef->outputIndex.emplace_back(outputIndex);
for (auto outputIndex : tflite_subgraph->outputs) {
int i = 0;
bool found = false;
for (const auto &tfliteOp : tflite_subgraph->operators) {
int j = 0;
auto opType = GetTfliteNodeType(tfliteOp, tflite_model);
std::string opName = opType + "-" + std::to_string(i++);
for (auto opOutputIndex : tfliteOp->outputs) {
if (outputIndex == opOutputIndex) {
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]);
found = true;
break;
}
j++;
}
if (found) {
break;
}
}
}
}
......@@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
return nullptr;
}
SetGraphTensorIndex(tensorCache, subGraph.get());
SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get());
return subGraph.release();
}
......
......@@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser {
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);
void SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册