Commit bed056aa authored by yeyunpeng

Fix DeDepthwiseConv2D problem

Parent 0ae5eeb3
@@ -79,7 +79,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
   switch (node->quantType) {
     case QuantType_QUANT_NONE: {
       if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D ||
-          opType == schema::PrimitiveType_DeConv2D) {
+          opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) {
         weightTensor->format = schema::Format_KCHW;
       } else {
         MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType)
@@ -240,11 +240,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
       }
     } else if (weightTensor->format == schema::Format_KHWC) {  // from onnx
       return RET_OK;
       // if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
       //   status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
       // } else {
       //   status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
       // }
     } else if (weightTensor->format == schema::Format_HWCK) {  // from tf
       return 0;
     } else {
@@ -275,7 +275,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
     } else if (weightTensor->format == schema::Format_HWCK) {  // from tf
       return 0;
     } else if (weightTensor->format == schema::Format_CHWK) {  // from onnx
       if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
         status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
       } else {
         status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
@@ -383,9 +383,11 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
         MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
         // todo(00445839): consider varible weight condition
       }
-    } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) {  // weight should be CKHW
-      if (weightTensor->format == schema::Format_CKHW) {  // from caffe
+    } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) {  // weight should be KHWC
+      if (weightTensor->format == schema::Format_KHWC) {
         return 0;
+      } else if (weightTensor->format == schema::Format_KCHW) {  // from caffe
+        status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
       } else if (weightTensor->format == schema::Format_HWKC) {  // from tf or onnx
         status = TransFilterFormat<float>(weightTensor.get(), kHWKC2CKHW);
       } else {
@@ -393,7 +395,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
         return -1;
       }
       if (status == 0) {
-        node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
+        node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC;
         weightTensor->format = schema::Format_CKHW;
       } else {
         MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();
......
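
The new kKCHW2KHWC branch above delegates the actual data movement to TransFilterFormat, whose body is not part of this diff. A minimal standalone sketch of the KCHW-to-KHWC permutation it is expected to perform could look like the following; the function name and loop structure are illustrative assumptions, not the converter's actual implementation:

#include <cstddef>
#include <vector>

// Hypothetical sketch: copy a dense 4-D filter from KCHW layout to KHWC.
// The real TransFilterFormat rewrites a schema::TensorT in place; this
// version only illustrates the index arithmetic of the permutation.
std::vector<float> TransposeKCHW2KHWC(const std::vector<float> &src,
                                      size_t K, size_t C, size_t H, size_t W) {
  std::vector<float> dst(src.size());
  for (size_t k = 0; k < K; ++k) {
    for (size_t c = 0; c < C; ++c) {
      for (size_t h = 0; h < H; ++h) {
        for (size_t w = 0; w < W; ++w) {
          // source index in KCHW order, destination index in KHWC order
          dst[((k * H + h) * W + w) * C + c] = src[((k * C + c) * H + h) * W + w];
        }
      }
    }
  }
  return dst;
}
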
@@ -46,14 +46,13 @@ void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schem
   deDepthwiseConv2DParam->hasBias = attr->hasBias;
   deDepthwiseConv2DParam->activationType = attr->activationType;
   delete attr;
-  op->primitive = std::make_unique<schema::PrimitiveT>();
   op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
   op->primitive->value.value = deDepthwiseConv2DParam.release();
 }

 STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
                                        schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
   op->name = proto.name();
-  schema::DeConv2DT *attr = new schema::DeConv2DT();
+  auto *attr = new schema::DeConv2DT();
   attr->format = schema::Format_NCHW;
   const caffe::ConvolutionParameter convParam = proto.convolution_param();
......
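
The make_unique line deleted from ParseGroupDeconvolution suggests op->primitive is already allocated before the helper runs, so re-creating it there would discard anything written to the primitive earlier. A minimal sketch of the single-allocation pattern the fix appears to enforce; the types are reduced to stubs and the names are illustrative, not the schema's real classes:

#include <memory>

// Stub types standing in for the flatbuffer-generated schema classes.
struct PrimitiveValueT { int type = 0; };
struct PrimitiveT { PrimitiveValueT value; };
struct CNodeT { std::unique_ptr<PrimitiveT> primitive; };

void FillGroupDeconv(CNodeT *op) {
  // Only populate the existing primitive; allocating a fresh one here
  // (as the removed make_unique call did) would overwrite prior state.
  op->primitive->value.type = 1;  // stand-in for PrimitiveType_DeDepthwiseConv2D
}

int main() {
  CNodeT node;
  node.primitive = std::make_unique<PrimitiveT>();  // allocated once by the caller
  FillGroupDeconv(&node);
  return 0;
}
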