提交 4d35336b 编写于 作者: Z zhaoying 提交者: jackzhang235

(feature): add group_conv and depthwise_conv

上级 0a68279c
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
#include "lite/operators/conv_op.h" #include "lite/operators/conv_op.h"
#include <algorithm> #include <algorithm>
#include "lite/kernels/mlu/bridges/graph.h" #include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h" #include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h" #include "lite/kernels/npu/bridges/registry.h"
...@@ -43,6 +45,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -43,6 +45,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
const auto output_shape = output->dims().Vectorize(); const auto output_shape = output->dims().Vectorize();
const auto bs = input_dims[0]; const auto bs = input_dims[0];
const auto oc = filter_dims[0]; const auto oc = filter_dims[0];
const auto groups = op_info->GetAttr<int>("groups");
CHECK_EQ(input_dims.size(), 4); CHECK_EQ(input_dims.size(), 4);
CHECK_EQ(filter_dims.size(), 4); CHECK_EQ(filter_dims.size(), 4);
const auto strides = op_info->GetAttr<std::vector<int>>("strides"); const auto strides = op_info->GetAttr<std::vector<int>>("strides");
...@@ -70,16 +74,55 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -70,16 +74,55 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
padding_algorithm, padding_algorithm,
input_dims, input_dims,
filter_dims); filter_dims);
bool is_group_mode = false;
if (groups > 1) {
is_group_mode = true;
}
const auto output_tensor = graph->AddNode( bool is_depthwise_mode = false;
output_var_name, output_shape, CNML_TENSOR, CNML_NCHW, graph->FPType()); if (filter_dims[0] == groups && filter_dims[1] == 1 && dilations[0] == 1 &&
dilations[1] == 1) { // depthwise filter shape = {1, ic ,kh ,kw}
is_depthwise_mode = true;
is_group_mode = false;
}
// Create filter node // ================ DEBUG =======================
const auto filter_tensor = graph->AddNode(filter_var_name,
filter_dims.Vectorize(), VLOG(4) << "conv2d op input_var_name : " << input_var_name << std::endl;
CNML_FILTER, VLOG(4) << "conv2d op : filter_var_name " << filter_var_name << std::endl;
CNML_NCHW, VLOG(4) << "conv2d op : output_var_name " << output_var_name << std::endl;
VLOG(4) << "conv2d op : groups " << groups << std::endl;
VLOG(4) << "conv2d op : is_depthwise_mode " << is_depthwise_mode<< std::endl;
VLOG(4) << "conv2d op : is_group_mode " << is_group_mode << std::endl;
// ================ DEBUG EDN =======================
const auto output_shape_nhwc = DimNCHW2NHWC(output_shape);
const auto output_tensor = graph->AddNode(output_var_name,
output_shape,
CNML_TENSOR,
CNML_NHWC,
graph->FPType()); graph->FPType());
scope->FindVar(output_var_name)
->GetMutable<::paddle::lite::Tensor>()
->Resize(output_shape_nhwc);
std::vector<int64_t> cnml_filter_shape = {
filter_dims[0], filter_dims[1], filter_dims[2], filter_dims[3]};
if (is_depthwise_mode) {
/*paddle filter shape is {oc , ic / groups == 1, kh, kw} while
cnml depthwise conv filter expect shape {oc / groups == 1 , ic , kh, kw}
so we should shape filter shape
*/
cnml_filter_shape = {
filter_dims[1], filter_dims[0], filter_dims[2], filter_dims[3]};
}
// Create filter node
std::shared_ptr<MLUTensor> filter_tensor = graph->AddNode(filter_var_name,
cnml_filter_shape,
CNML_FILTER,
CNML_NCHW,
graph->FPType());
const auto weight_scale = const auto weight_scale =
op_info->GetAttr<std::vector<float>>("weight_scale"); op_info->GetAttr<std::vector<float>>("weight_scale");
...@@ -89,15 +132,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -89,15 +132,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
dequant(filter_dequant.data(), dequant(filter_dequant.data(),
filter->mutable_data<int8_t>(), filter->mutable_data<int8_t>(),
1, 1,
filter_dims[0], cnml_filter_shape[0],
filter_dims[1] * filter_dims[2] * filter_dims[3], cnml_filter_shape[1] * cnml_filter_shape[2] * cnml_filter_shape[3],
weight_scale); weight_scale);
transpose(filter_dequant.data(), transpose(filter_dequant.data(),
filter->mutable_data<float>(), filter->mutable_data<float>(),
{static_cast<int>(filter_dims[0]), {static_cast<int>(cnml_filter_shape[0]),
static_cast<int>(filter_dims[1]), static_cast<int>(cnml_filter_shape[1]),
static_cast<int>(filter_dims[2]), static_cast<int>(cnml_filter_shape[2]),
static_cast<int>(filter_dims[3])}, static_cast<int>(cnml_filter_shape[3])},
{0, 2, 3, 1}); {0, 2, 3, 1});
filter->set_precision(PrecisionType::kFloat); filter->set_precision(PrecisionType::kFloat);
} else if (filter->precision() != PrecisionType::kFloat) { } else if (filter->precision() != PrecisionType::kFloat) {
...@@ -188,6 +231,39 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -188,6 +231,39 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bias_tensor ? bias_tensor->mlu_tensor() : nullptr, bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
std_tensor->mlu_tensor())); std_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyConvFirstOpParam(&conv_param)); CNML_CALL(cnmlDestroyConvFirstOpParam(&conv_param));
} else if (is_depthwise_mode) {
cnmlConvDepthwiseOpParam_t conv_depthwise_param;
cnmlCreateConvDepthwiseOpParam_V2(&conv_depthwise_param,
strides[0],
strides[1],
paddings[0] * 2,
paddings[2] * 2);
CNML_CALL(cnmlCreateConvDepthwiseOpForward(
&conv_op,
conv_depthwise_param,
graph->GetNode(input_var_name)->mlu_tensor(),
output_tensor->mlu_tensor(),
filter_tensor->mlu_tensor(),
bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
CNML_CALL(cnmlDestroyConvDepthwiseOpParam(&conv_depthwise_param));
} else if (is_group_mode) {
cnmlConvOpParam_t conv_param;
CNML_CALL(cnmlCreateConvOpParam(&conv_param,
strides[0],
strides[1],
dilations[0],
dilations[1],
paddings[0] * 2,
paddings[2] * 2));
CNML_CALL(cnmlCreateConvGroupOpForward(
&conv_op,
conv_param,
graph->GetNode(input_var_name)->mlu_tensor(),
output_tensor->mlu_tensor(),
filter_tensor->mlu_tensor(),
bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
groups));
CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
} else { } else {
cnmlConvOpParam_t conv_param; cnmlConvOpParam_t conv_param;
CNML_CALL(cnmlCreateConvOpParam(&conv_param, CNML_CALL(cnmlCreateConvOpParam(&conv_param,
...@@ -207,12 +283,14 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -207,12 +283,14 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CNML_CALL(cnmlDestroyConvOpParam(&conv_param)); CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
} }
graph->SetComputingDataType( if (!is_depthwise_mode) {
conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale); graph->SetComputingDataType(
graph->SetComputingDataType( conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale);
conv_op, graph->SetComputingDataType(
filter_tensor->mlu_tensor(), conv_op,
1 / *min_element(weight_scale.begin(), weight_scale.end())); filter_tensor->mlu_tensor(),
1 / *min_element(weight_scale.begin(), weight_scale.end()));
}
CNML_CALL(cnmlSetOperationComputingLayout(conv_op, CNML_NHWC)); CNML_CALL(cnmlSetOperationComputingLayout(conv_op, CNML_NHWC));
if (HasInputArg(op_info, scope, "Bias")) { if (HasInputArg(op_info, scope, "Bias")) {
auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>(); auto* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
// limitations under the License. // limitations under the License.
#include "lite/operators/conv_op.h" #include "lite/operators/conv_op.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <random> #include <random>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h" #include "lite/kernels/mlu/bridges/test_helper.h"
...@@ -331,6 +334,10 @@ TEST(MLUBridges, conv) { ...@@ -331,6 +334,10 @@ TEST(MLUBridges, conv) {
#endif #endif
} }
TEST(MLUBridges, depthwise_conv2d) {
test_conv(1, 8, 8, 14, 14, false, false, false, true, 1, 1, 2, 3);
}
} // namespace mlu } // namespace mlu
} // namespace subgraph } // namespace subgraph
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册