未验证 提交 3244a9de 编写于 作者: 光明和真理 提交者: GitHub

[MLU] transpose convbpf output to HWCN for better performance (#44552)

上级 f9cd526b
......@@ -436,6 +436,8 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
Tensor output_grad_tensor(output_grad->type());
const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
const std::vector<int> perm_hwcm_to_mchw = {3, 2, 0, 1};
const std::vector<int> perm_mchw_to_hwcm = {2, 3, 1, 0};
if (channel_last) {
input_tensor.ShareDataWith(*input);
output_grad_tensor.ShareDataWith(*output_grad);
......@@ -462,10 +464,12 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
auto filter_grad_dims = filter_grad->dims();
Tensor temp_filter_grad(filter_grad->type());
temp_filter_grad.mutable_data<T>({filter_grad_dims[0],
filter_grad_dims[2],
filter_grad_dims[3],
filter_grad_dims[1]},
// Details about setting diff_w hwcn for better performance, see the CNNL
// documentation.
temp_filter_grad.mutable_data<T>({filter_grad_dims[perm_mchw_to_hwcm[0]],
filter_grad_dims[perm_mchw_to_hwcm[1]],
filter_grad_dims[perm_mchw_to_hwcm[2]],
filter_grad_dims[perm_mchw_to_hwcm[3]]},
ctx.GetPlace());
cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
......@@ -474,7 +478,7 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc out_grad_desc(
output_grad_tensor, data_layout, tensor_dtype);
MLUCnnlTensorDesc temp_filter_grad_desc(
temp_filter_grad, data_layout, tensor_dtype);
temp_filter_grad, CNNL_LAYOUT_HWCN, tensor_dtype);
MLUCnnlConvolutionDesc conv_desc(in_dims_size,
paddings.data(),
......@@ -492,9 +496,9 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
temp_filter_grad_desc.get(),
GetBasePtr(&temp_filter_grad));
// transpose filter_grad from MHWC to MCHW
// transpose filter_grad from HWCM to MCHW
TransposeFromMLUTensor<T>(ctx,
perm_to_nchw,
perm_hwcm_to_mchw,
&temp_filter_grad,
filter_grad,
false /*need_reshape_or_alloc*/);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册