提交 bed056aa 编写于 作者: Y yeyunpeng

Fix DeDepthwiseConv2D problem

上级 0ae5eeb3
......@@ -79,7 +79,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
switch (node->quantType) {
case QuantType_QUANT_NONE: {
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D ||
opType == schema::PrimitiveType_DeConv2D) {
opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) {
weightTensor->format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType)
......@@ -240,11 +240,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
return RET_OK;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else {
......@@ -275,7 +275,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
......@@ -383,9 +383,11 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be CKHW
if (weightTensor->format == schema::Format_CKHW) { // from caffe
} else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KHWC) {
return 0;
} else if (weightTensor->format == schema::Format_KCHW) { // from caffe
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2CKHW);
} else {
......@@ -393,7 +395,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
return -1;
}
if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_CKHW;
} else {
MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();
......
......@@ -46,14 +46,13 @@ void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schem
deDepthwiseConv2DParam->hasBias = attr->hasBias;
deDepthwiseConv2DParam->activationType = attr->activationType;
delete attr;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
}
STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
op->name = proto.name();
schema::DeConv2DT *attr = new schema::DeConv2DT();
auto *attr = new schema::DeConv2DT();
attr->format = schema::Format_NCHW;
const caffe::ConvolutionParameter convParam = proto.convolution_param();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册