Unverified commit 3244a9de, authored by 光明和真理, committed by GitHub

[MLU]transpose convbpf output to HWCN for better performance (#44552)

Parent: f9cd526b
@@ -436,6 +436,8 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
     Tensor output_grad_tensor(output_grad->type());
     const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
     const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    const std::vector<int> perm_hwcm_to_mchw = {3, 2, 0, 1};
+    const std::vector<int> perm_mchw_to_hwcm = {2, 3, 1, 0};
     if (channel_last) {
       input_tensor.ShareDataWith(*input);
       output_grad_tensor.ShareDataWith(*output_grad);
@@ -462,10 +464,12 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
       auto filter_grad_dims = filter_grad->dims();
       Tensor temp_filter_grad(filter_grad->type());
-      temp_filter_grad.mutable_data<T>({filter_grad_dims[0],
-                                        filter_grad_dims[2],
-                                        filter_grad_dims[3],
-                                        filter_grad_dims[1]},
+      // Details about setting diff_w hwcn for better performance, see the CNNL
+      // documentation.
+      temp_filter_grad.mutable_data<T>({filter_grad_dims[perm_mchw_to_hwcm[0]],
+                                        filter_grad_dims[perm_mchw_to_hwcm[1]],
+                                        filter_grad_dims[perm_mchw_to_hwcm[2]],
+                                        filter_grad_dims[perm_mchw_to_hwcm[3]]},
                                        ctx.GetPlace());
       cnnlDataType_t tensor_dtype = ToCnnlDataType<T>();
@@ -474,7 +478,7 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
       MLUCnnlTensorDesc out_grad_desc(
           output_grad_tensor, data_layout, tensor_dtype);
       MLUCnnlTensorDesc temp_filter_grad_desc(
-          temp_filter_grad, data_layout, tensor_dtype);
+          temp_filter_grad, CNNL_LAYOUT_HWCN, tensor_dtype);
       MLUCnnlConvolutionDesc conv_desc(in_dims_size,
                                        paddings.data(),
@@ -492,9 +496,9 @@ class MLUDepthwiseConvGradOpKernel : public framework::OpKernel<T> {
                          temp_filter_grad_desc.get(),
                          GetBasePtr(&temp_filter_grad));
-      // transpose filter_grad from MHWC to MCHW
+      // transpose filter_grad from HWCM to MCHW
       TransposeFromMLUTensor<T>(ctx,
                                 perm_hwcm_to_mchw,
                                 &temp_filter_grad,
                                 filter_grad,
                                 false /*need_reshape_or_alloc*/);
......
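
For reference, the two permutation vectors introduced by this patch are inverses of each other: perm_mchw_to_hwcm = {2, 3, 1, 0} picks the MCHW filter-gradient dimensions into HWCM order when allocating temp_filter_grad for the CNNL backward-filter call, and perm_hwcm_to_mchw = {3, 2, 0, 1} maps the HWCM result back to the MCHW layout expected by filter_grad. The standalone sketch below is not PaddlePaddle or CNNL code; the filter shape {32, 1, 3, 3} and the Permute helper are illustrative assumptions used only to show the shape round-trip.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Shape permutation with the same convention as the patch: out[i] = dims[perm[i]].
std::vector<int64_t> Permute(const std::vector<int64_t>& dims,
                             const std::vector<int>& perm) {
  std::vector<int64_t> out(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i) {
    out[i] = dims[perm[i]];
  }
  return out;
}

int main() {
  // Hypothetical depthwise filter-gradient shape in MCHW order:
  // M = 32 output channels, C = 1, H = 3, W = 3.
  const std::vector<int64_t> mchw = {32, 1, 3, 3};

  // Permutation vectors copied from the patch.
  const std::vector<int> perm_mchw_to_hwcm = {2, 3, 1, 0};
  const std::vector<int> perm_hwcm_to_mchw = {3, 2, 0, 1};

  // Shape used to allocate temp_filter_grad: {3, 3, 1, 32} (HWCM).
  const std::vector<int64_t> hwcm = Permute(mchw, perm_mchw_to_hwcm);

  // Transposing back with perm_hwcm_to_mchw recovers {32, 1, 3, 3} (MCHW).
  const std::vector<int64_t> back = Permute(hwcm, perm_hwcm_to_mchw);

  for (int64_t d : hwcm) std::cout << d << ' ';
  std::cout << '\n';
  for (int64_t d : back) std::cout << d << ' ';
  std::cout << '\n';
}

Under these permutations, the gradient is produced directly in the HWCN layout that CNNL prefers for diff_w and then transposed once back to MCHW, which is what the changed TransposeFromMLUTensor call performs.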