未验证 提交 7783d3bd 编写于 作者: W wangchaochaohu 提交者: GitHub

Conv refine (#20644)

* Add a condition check (`if (!is_sys_pad)`) around the Slice_2 step for performance improvement test=develop

* Add a condition check (`if (!is_sys_pad)`) around the Slice_2 step for performance improvement test=develop

* refine code style test=develop
上级 57b656f9
...@@ -540,23 +540,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -540,23 +540,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size); workspace_size);
} }
std::vector<int> starts(transformed_input_channel.dims().size(), 0); if (!is_sys_pad) {
std::vector<int> axes(transformed_input_channel.dims().size(), 0); std::vector<int> starts(transformed_input_channel.dims().size(), 0);
std::vector<int> axes(transformed_input_channel.dims().size(), 0);
for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) {
starts[i] = input_pad[2 * i]; starts[i] = input_pad[2 * i];
axes[i] = i; axes[i] = i;
} }
transformed_input_grad_channel.mutable_data(ctx.GetPlace()); transformed_input_grad_channel.mutable_data(ctx.GetPlace());
if (transformed_input_channel.dims().size() == 4) { if (transformed_input_channel.dims().size() == 4) {
Slice_2<paddle::platform::CUDADeviceContext, T, 4>( Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
ctx, &transformed_input_grad, &transformed_input_grad_channel, ctx, &transformed_input_grad, &transformed_input_grad_channel,
starts, axes); starts, axes);
} else { } else {
Slice_2<paddle::platform::CUDADeviceContext, T, 5>( Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_input_grad, &transformed_input_grad_channel, ctx, &transformed_input_grad, &transformed_input_grad_channel,
starts, axes); starts, axes);
}
} }
if (channel_last) { if (channel_last) {
...@@ -982,20 +984,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -982,20 +984,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
workspace_size); workspace_size);
} }
// reverse padded input if (!is_sys_pad) {
std::vector<int> starts(X->dims().size(), 0); // reverse padded input
std::vector<int> axes(X->dims().size(), 0); std::vector<int> starts(X->dims().size(), 0);
std::vector<int> axes(X->dims().size(), 0);
for (size_t i = 0; i < X->dims().size(); ++i) { for (size_t i = 0; i < X->dims().size(); ++i) {
starts[i] = input_pad[2 * i]; starts[i] = input_pad[2 * i];
axes[i] = i; axes[i] = i;
} }
if (X->dims().size() == 4) { if (X->dims().size() == 4) {
Slice_2<paddle::platform::CUDADeviceContext, T, 4>( Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes); ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
} else { } else {
Slice_2<paddle::platform::CUDADeviceContext, T, 5>( Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes); ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
}
} }
if (channel_last) { if (channel_last) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>( TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册