提交 2ae2c3ce 编写于 作者: C cjh9368

solve aware quantizer memory problem

上级 18c6ac99
......@@ -293,13 +293,13 @@ STATUS AwareQuantizer::GenerateQuantParam() {
MS_ASSERT(graph->inputIndex.size() == 1);
// set graphInputNode input
for (auto graphInputIndex : graph->inputIndex) {
auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex);
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetInputArrayQP failed";
return status;
}
}
auto status = GenerateDefaultQuantParam(graph.get());
auto status = GenerateDefaultQuantParam(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateDefaultQuantParam failed";
return status;
......@@ -319,7 +319,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
status = quantParamCalcer->Calc(graph.get(), *node);
status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
......@@ -349,27 +349,27 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_ERROR;
}
// quant weight
status = QuantConvWeight(graph.get(), node.get());
status = QuantConvWeight(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvWeight failed!";
return RET_ERROR;
}
// quant bias
if (inputIndexes.size() == 3) {
status = QuantConvBias(graph.get(), node.get());
status = QuantConvBias(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvBias failed!";
return RET_ERROR;
}
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph.get(), node.get());
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
return RET_ERROR;
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
status = QuantAddConstTensor(graph.get(), node.get());
status = QuantAddConstTensor(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantAddConstTensor failed!";
return RET_ERROR;
......
......@@ -73,7 +73,7 @@ class FbQuantizer {
virtual STATUS DoQuantize() = 0;
protected:
std::shared_ptr<schema::MetaGraphT> graph = nullptr;
schema::MetaGraphT *graph = nullptr;
};
} // namespace mindspore::lite::quant
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册