diff --git a/lite/kernels/mlu/bridges/conv_op.cc b/lite/kernels/mlu/bridges/conv_op.cc
index 6a7ef408eb7432950d5a0985dd6e174236e937e0..2db9cfbd785ad727ecfda7461ab11b2a7f8d738a 100644
--- a/lite/kernels/mlu/bridges/conv_op.cc
+++ b/lite/kernels/mlu/bridges/conv_op.cc
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "lite/operators/conv_op.h"
+
 #include <algorithm>
+
 #include "lite/kernels/mlu/bridges/graph.h"
 #include "lite/kernels/mlu/bridges/utility.h"
 #include "lite/kernels/npu/bridges/registry.h"
@@ -43,6 +45,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   const auto output_shape = output->dims().Vectorize();
   const auto bs = input_dims[0];
   const auto oc = filter_dims[0];
+  const auto groups = op_info->GetAttr<int>("groups");
+
   CHECK_EQ(input_dims.size(), 4);
   CHECK_EQ(filter_dims.size(), 4);
   const auto strides = op_info->GetAttr<std::vector<int>>("strides");
@@ -70,16 +74,55 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                                       padding_algorithm,
                                       input_dims,
                                       filter_dims);
+  bool is_group_mode = false;
+  if (groups > 1) {
+    is_group_mode = true;
+  }
 
-  const auto output_tensor = graph->AddNode(
-      output_var_name, output_shape, CNML_TENSOR, CNML_NCHW, graph->FPType());
+  bool is_depthwise_mode = false;
+  if (filter_dims[0] == groups && filter_dims[1] == 1 && dilations[0] == 1 &&
+      dilations[1] == 1) {  // depthwise filter shape = {1, ic, kh, kw}
+    is_depthwise_mode = true;
+    is_group_mode = false;
+  }
 
-  // Create filter node
-  const auto filter_tensor = graph->AddNode(filter_var_name,
-                                            filter_dims.Vectorize(),
-                                            CNML_FILTER,
-                                            CNML_NCHW,
+  // ================ DEBUG =======================
+
+  VLOG(4) << "conv2d op input_var_name : " << input_var_name << std::endl;
+  VLOG(4) << "conv2d op : filter_var_name " << filter_var_name << std::endl;
+  VLOG(4) << "conv2d op : output_var_name " << output_var_name << std::endl;
+  VLOG(4) << "conv2d op : groups " << groups << std::endl;
+  VLOG(4) << "conv2d op : is_depthwise_mode " << is_depthwise_mode << std::endl;
+  VLOG(4) << "conv2d op : is_group_mode " << is_group_mode << std::endl;
+
+  // ================ DEBUG END =======================
+  const auto output_shape_nhwc = DimNCHW2NHWC(output_shape);
+  const auto output_tensor = graph->AddNode(output_var_name,
+                                            output_shape,
+                                            CNML_TENSOR,
+                                            CNML_NHWC,
                                             graph->FPType());
+  scope->FindVar(output_var_name)
+      ->GetMutable<::paddle::lite::Tensor>()
+      ->Resize(output_shape_nhwc);
+
+  std::vector<int64_t> cnml_filter_shape = {
+      filter_dims[0], filter_dims[1], filter_dims[2], filter_dims[3]};
+  if (is_depthwise_mode) {
+    /* paddle filter shape is {oc, ic / groups == 1, kh, kw} while the
+       cnml depthwise conv filter expects shape {oc / groups == 1, ic, kh, kw},
+       so we need to swap the first two filter dimensions
+    */
+    cnml_filter_shape = {
+        filter_dims[1], filter_dims[0], filter_dims[2], filter_dims[3]};
+  }
+
+  // Create filter node
+  std::shared_ptr<MLUTensor> filter_tensor = graph->AddNode(filter_var_name,
+                                                            cnml_filter_shape,
+                                                            CNML_FILTER,
+                                                            CNML_NCHW,
+                                                            graph->FPType());
 
   const auto weight_scale =
       op_info->GetAttr<std::vector<float>>("weight_scale");
@@ -89,15 +132,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     dequant(filter_dequant.data(),
             filter->mutable_data<int8_t>(),
             1,
-            filter_dims[0],
-            filter_dims[1] * filter_dims[2] * filter_dims[3],
+            cnml_filter_shape[0],
+            cnml_filter_shape[1] * cnml_filter_shape[2] * cnml_filter_shape[3],
             weight_scale);
     transpose(filter_dequant.data(),
               filter->mutable_data<float>(),
-              {static_cast<int>(filter_dims[0]),
-               static_cast<int>(filter_dims[1]),
-               static_cast<int>(filter_dims[2]),
-               static_cast<int>(filter_dims[3])},
+              {static_cast<int>(cnml_filter_shape[0]),
+               static_cast<int>(cnml_filter_shape[1]),
+               static_cast<int>(cnml_filter_shape[2]),
+               static_cast<int>(cnml_filter_shape[3])},
               {0, 2, 3, 1});
     filter->set_precision(PrecisionType::kFloat);
   } else if (filter->precision() != PrecisionType::kFloat) {
@@ -188,6 +231,39 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
         bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
         std_tensor->mlu_tensor()));
     CNML_CALL(cnmlDestroyConvFirstOpParam(&conv_param));
+  } else if (is_depthwise_mode) {
+    cnmlConvDepthwiseOpParam_t conv_depthwise_param;
+    cnmlCreateConvDepthwiseOpParam_V2(&conv_depthwise_param,
+                                      strides[0],
+                                      strides[1],
+                                      paddings[0] * 2,
+                                      paddings[2] * 2);
+    CNML_CALL(cnmlCreateConvDepthwiseOpForward(
+        &conv_op,
+        conv_depthwise_param,
+        graph->GetNode(input_var_name)->mlu_tensor(),
+        output_tensor->mlu_tensor(),
+        filter_tensor->mlu_tensor(),
+        bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
+    CNML_CALL(cnmlDestroyConvDepthwiseOpParam(&conv_depthwise_param));
+  } else if (is_group_mode) {
+    cnmlConvOpParam_t conv_param;
+    CNML_CALL(cnmlCreateConvOpParam(&conv_param,
+                                    strides[0],
+                                    strides[1],
+                                    dilations[0],
+                                    dilations[1],
+                                    paddings[0] * 2,
+                                    paddings[2] * 2));
+    CNML_CALL(cnmlCreateConvGroupOpForward(
+        &conv_op,
+        conv_param,
+        graph->GetNode(input_var_name)->mlu_tensor(),
+        output_tensor->mlu_tensor(),
+        filter_tensor->mlu_tensor(),
+        bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
+        groups));
+    CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
   } else {
     cnmlConvOpParam_t conv_param;
     CNML_CALL(cnmlCreateConvOpParam(&conv_param,
@@ -207,12 +283,14 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
   }
 
-  graph->SetComputingDataType(
-      conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale);
-  graph->SetComputingDataType(
-      conv_op,
-      filter_tensor->mlu_tensor(),
-      1 / *min_element(weight_scale.begin(), weight_scale.end()));
+  if (!is_depthwise_mode) {
+    graph->SetComputingDataType(
+        conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale);
+    graph->SetComputingDataType(
+        conv_op,
+        filter_tensor->mlu_tensor(),
+        1 / *min_element(weight_scale.begin(), weight_scale.end()));
+  }
   CNML_CALL(cnmlSetOperationComputingLayout(conv_op, CNML_NHWC));
   if (HasInputArg(op_info, scope, "Bias")) {
     auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
diff --git a/lite/kernels/mlu/bridges/conv_op_test.cc b/lite/kernels/mlu/bridges/conv_op_test.cc
index e34dd7c2a85dbda62596b6e82d820fc437bfd194..3506651d87d2f13ba8b1fefeba50329291dc3742 100644
--- a/lite/kernels/mlu/bridges/conv_op_test.cc
+++ b/lite/kernels/mlu/bridges/conv_op_test.cc
@@ -13,8 +13,11 @@
 // limitations under the License.
 
 #include "lite/operators/conv_op.h"
+
 #include <gtest/gtest.h>
+
 #include <random>
+
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
 #include "lite/kernels/mlu/bridges/test_helper.h"
@@ -331,6 +334,10 @@ TEST(MLUBridges, conv) {
 #endif
 }
 
+TEST(MLUBridges, depthwise_conv2d) {
+  test_conv(1, 8, 8, 14, 14, false, false, false, true, 1, 1, 2, 3);
+}
+
 }  // namespace mlu
 }  // namespace subgraph
 }  // namespace lite
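
For reference, the mode selection and filter-layout swap that the patch performs before handing the filter tensor to CNML can be summarized in a small standalone sketch. The ConvMode enum and the ClassifyConv / CnmlFilterShape helpers below are invented names for illustration and are not part of the MLU bridge API; only the conditions and the axis swap mirror the diff above.

// Illustrative sketch: mirrors the is_depthwise_mode / is_group_mode /
// cnml_filter_shape logic added to ConvConverter. Names are hypothetical.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

enum class ConvMode { kNormal, kGroup, kDepthwise };

// Depthwise requires oc == groups, ic / groups == 1, and no dilation;
// any other conv with groups > 1 falls back to the grouped path.
ConvMode ClassifyConv(const std::vector<int64_t>& filter_dims,  // {oc, ic/g, kh, kw}
                      int groups,
                      const std::vector<int>& dilations) {
  assert(filter_dims.size() == 4 && dilations.size() == 2);
  if (filter_dims[0] == groups && filter_dims[1] == 1 && dilations[0] == 1 &&
      dilations[1] == 1) {
    return ConvMode::kDepthwise;
  }
  return groups > 1 ? ConvMode::kGroup : ConvMode::kNormal;
}

// Paddle stores a depthwise filter as {oc, ic / groups == 1, kh, kw}; CNML's
// depthwise conv expects {oc / groups == 1, ic, kh, kw}, i.e. the first two
// axes swapped. Normal and grouped convs keep the Paddle layout.
std::vector<int64_t> CnmlFilterShape(const std::vector<int64_t>& filter_dims,
                                     ConvMode mode) {
  if (mode == ConvMode::kDepthwise) {
    return {filter_dims[1], filter_dims[0], filter_dims[2], filter_dims[3]};
  }
  return filter_dims;
}

int main() {
  // An 8-channel 3x3 depthwise filter: {oc = 8, ic / groups = 1, kh = 3, kw = 3}.
  const std::vector<int64_t> filter = {8, 1, 3, 3};
  const ConvMode mode = ClassifyConv(filter, /*groups=*/8, /*dilations=*/{1, 1});
  const auto cnml_shape = CnmlFilterShape(filter, mode);
  std::cout << (mode == ConvMode::kDepthwise ? "depthwise: {" : "other: {")
            << cnml_shape[0] << ", " << cnml_shape[1] << ", " << cnml_shape[2]
            << ", " << cnml_shape[3] << "}\n";  // prints: depthwise: {1, 8, 3, 3}
  return 0;
}

Beyond the filter layout, the patch routes the three cases to different CNML ops (cnmlCreateConvDepthwiseOpForward for depthwise, cnmlCreateConvGroupOpForward for grouped conv, and the existing conv path otherwise), and it applies the SetComputingDataType quantization hints only on the non-depthwise paths.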