未验证 提交 7b9b7303 编写于 作者: W wangchaochaohu 提交者: GitHub

Conv refine (#20644) (#20671)

* add condition judgement for performance improvement test=develop

* add condition judgement for performance improvement test=develop

* refine code style test=develop
上级 90d05bbd
......@@ -540,6 +540,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size);
}
if (!is_sys_pad) {
std::vector<int> starts(transformed_input_channel.dims().size(), 0);
std::vector<int> axes(transformed_input_channel.dims().size(), 0);
......@@ -558,6 +559,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
ctx, &transformed_input_grad, &transformed_input_grad_channel,
starts, axes);
}
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
......@@ -982,6 +984,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
workspace_size);
}
if (!is_sys_pad) {
// reverse padded input
std::vector<int> starts(X->dims().size(), 0);
std::vector<int> axes(X->dims().size(), 0);
......@@ -997,6 +1000,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
}
}
if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_dX_channel, dX);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册