Commit 2ae4214f authored by kai00

fusion mem check fixed

Parent 33a562de
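The change below frees the transScale/transBias buffers owned by DoFusion on every exit path and drops the class destructor that previously released them. As a rough sketch of the pattern (not the actual pass code), the helper below mirrors the repeated cleanup block, and an RAII variant with std::unique_ptr<float[]> shows how the repetition could be avoided; the names, channel count, and function bodies are illustrative assumptions.

```cpp
#include <memory>

// Manual pattern applied by this commit: delete[] both buffers and reset the
// pointers before every return, so nothing is left for a destructor to free.
static void ReleaseTransBuffers(float *&transScale, float *&transBias) {
  delete[] transScale;  // delete[] matches an allocation like new float[bnChannel]
  delete[] transBias;
  transScale = nullptr;
  transBias = nullptr;
}

// RAII alternative (illustration only): unique_ptr<float[]> releases the arrays
// automatically on any early return, with no per-error-path cleanup needed.
static bool FoldExample(int bnChannel) {
  auto transScale = std::make_unique<float[]>(bnChannel);
  auto transBias = std::make_unique<float[]>(bnChannel);
  transScale[0] = 1.0f;  // ... fill per-channel values, return false on error ...
  transBias[0] = 0.0f;
  return true;
}

int main() {
  float *scale = new float[8]();
  float *bias = new float[8]();
  ReleaseTransBuffers(scale, bias);  // both pointers are nullptr afterwards
  return FoldExample(8) ? 0 : 1;
}
```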
@@ -90,23 +90,29 @@ STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string
     return RET_OK;
   }
   auto bnPath = matchedPath.at(bnOpName);
+  status = GetTransParam(graph, bnPath);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GetTransParam failed: " << status;
+    return status;
+  }
   status = GenNewScaleTensor(graph, bnPath);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
+    delete[] transScale;
+    delete[] transBias;
+    transScale = nullptr;
+    transBias = nullptr;
     return status;
   }
   status = ConvertBNToScale(graph, bnPath);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "ConvertBNToScale failed: " << status;
+    delete[] transScale;
+    delete[] transBias;
+    transScale = nullptr;
+    transBias = nullptr;
     return status;
   }
+  delete[] transScale;
+  delete[] transBias;
+  transScale = nullptr;
+  transBias = nullptr;
   return RET_OK;
 }
 STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
@@ -245,6 +251,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh
   // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
   if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
     MS_LOG(ERROR) << "memcpy_s transScale error";
+    delete[] transScale;
+    delete[] transBias;
+    transScale = nullptr;
+    transBias = nullptr;
     return RET_ERROR;
   }
   // 1/sqrt(variance + eps)
@@ -370,14 +380,5 @@ STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) {
   }
   return RET_OK;
 }
-BatchNormConvertScalePass::~BatchNormConvertScalePass() {
-  if (this->transScale != nullptr) {
-    delete (this->transScale);
-  }
-  if (this->transBias != nullptr) {
-    delete (this->transBias);
-  }
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -36,7 +36,7 @@ class BatchNormConvertScalePass : public FusionPass {
  public:
   BatchNormConvertScalePass() = default;
-  ~BatchNormConvertScalePass() override;
+  ~BatchNormConvertScalePass() = default;
   STATUS DefinePattern() override;
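For context on the comment in the first file (transScale is scale / sqrt(variance + eps) for TF, 1 / sqrt(variance + eps) for Caffe): folding a BatchNorm into a Scale op computes one multiplier and one offset per channel so that y = transScale * x + transBias. A minimal sketch under the standard folding identity; the transBias formula and all names and parameters here are assumptions for illustration, not code from the pass.

```cpp
#include <cmath>

// Fold BatchNorm statistics into per-channel Scale parameters.
//   transScale = scale / sqrt(variance + eps)   (TF; Caffe has no scale, so use 1)
//   transBias  = bias - mean * transScale       (standard folding identity, assumed)
void ComputeTransParams(const float *mean, const float *variance, const float *scale,
                        const float *bias, float eps, int channels,
                        float *transScale, float *transBias) {
  for (int c = 0; c < channels; ++c) {
    const float s = (scale != nullptr) ? scale[c] : 1.0f;  // Caffe-style BN: no scale input
    const float b = (bias != nullptr) ? bias[c] : 0.0f;    // Caffe-style BN: no bias input
    transScale[c] = s / std::sqrt(variance[c] + eps);
    transBias[c] = b - mean[c] * transScale[c];
  }
}

int main() {
  const float mean[1] = {0.5f}, variance[1] = {4.0f}, scale[1] = {2.0f}, bias[1] = {1.0f};
  float transScale[1] = {0.0f}, transBias[1] = {0.0f};
  ComputeTransParams(mean, variance, scale, bias, 1e-5f, 1, transScale, transBias);
  // transScale[0] ~ 2 / sqrt(4.00001) ~ 1.0, transBias[0] ~ 1 - 0.5 * 1.0 = 0.5
  return 0;
}
```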