提交 9638139e 编写于 作者: K kai00

fixed format trans

上级 1351d5f4
......@@ -1168,6 +1168,7 @@ int AnfImporterFromProtobuf::Import() {
const onnx::GraphProto &graphBuild = onnx_model_->graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
func_graph_ = nullptr;
return RET_ERROR;
}
func_graph_ = dstGraph;
......
......@@ -96,7 +96,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm};
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm};
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
......
......@@ -234,7 +234,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
......
......@@ -351,7 +351,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
if (fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册