Unverified    Commit 23845893 authored by lvmengsi, committed by GitHub

Fix conv_grad_grad (#20469)

* fix_conv_grad_grad

* fix_bug, test=develop
Parent 82992033
@@ -609,6 +609,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto dX = ctx.Output<Tensor>("DInput");
     if (ddO) {
       ddO->mutable_data<T>(ctx.GetPlace());
+      math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+      set_zero(dev_ctx, ddO, static_cast<T>(0));
     }
     if (dW) {
       dW->mutable_data<T>(ctx.GetPlace());
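Inferred from the guards added further down: ddO is the sum of two contributions (one through ddX, one through ddW), and when ddX is absent that half of the sum is never written, so a freshly allocated ddO would otherwise hold indeterminate memory. A minimal sketch of the guard-and-zero pattern, using hypothetical stand-in types rather than Paddle's real headers:

#include <cstdlib>

// Hypothetical stand-in, for illustration only -- not Paddle's real type.
struct Tensor {
  float* data = nullptr;
  std::size_t n = 0;
  void mutable_data(std::size_t count) {  // raw allocation: contents indeterminate
    data = static_cast<float*>(std::malloc(count * sizeof(float)));
    n = count;
  }
};

// ddO receives two additive contributions; either source may be absent.
void DoubleGradSketch(const Tensor* ddX, const Tensor* ddW, Tensor* ddO) {
  if (ddO) {
    ddO->mutable_data(16);
    for (std::size_t i = 0; i < ddO->n; ++i) ddO->data[i] = 0.0f;  // SetConstant analogue
  }
  if (ddX && ddO) { /* accumulate conv(ddX, W) into ddO */ }
  if (ddW && ddO) { /* accumulate conv(X, ddW) into ddO */ }
}

int main() {
  Tensor ddO;
  DoubleGradSketch(nullptr, nullptr, &ddO);  // ddO is all zeros, not garbage
  std::free(ddO.data);
}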
@@ -646,7 +648,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     // transform Tensors to channel first-----------
     Tensor transformed_X_channel(X->type());
     Tensor transformed_dO_channel(dO->type());
-    Tensor transformed_ddX_channel(ddX->type());
+    Tensor transformed_ddX_channel(X->type());
     Tensor transformed_ddO_channel(dO->type());
     Tensor transformed_dX_channel(X->type());
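This one-line change matters because the constructor runs before any `if (ddX)` guard, so with a null ddX the old `ddX->type()` dereferenced a null pointer. A provided ddX has the same dtype as X, which makes `X->type()` a safe equivalent. A tiny stand-alone illustration (hypothetical type, not Paddle code):

// Hypothetical stand-in for framework::Tensor, for illustration only.
struct FakeTensor {
  int dtype = 0;
  int type() const { return dtype; }  // undefined behavior through a null pointer
};

int main() {
  FakeTensor X;
  X.dtype = 5;
  const FakeTensor* ddX = nullptr;  // DDInput was not fed to the op
  // int t = ddX->type();           // old code path: null dereference
  int t = X.type();                 // fixed path: ddX, when present, matches X's dtype
  return t == 5 ? 0 : 1;
}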
@@ -662,10 +664,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       TransToChannelFirst<platform::CUDADeviceContext, T>(
           ctx, dO, &transformed_dO_channel);
-      ResizeToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
-      TransToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
+      if (ddX) {
+        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+        TransToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+      }
       if (ddO) {
         ResizeToChannelFirst<platform::CUDADeviceContext, T>(
@@ -680,7 +684,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     } else {
       transformed_X_channel = *X;
       transformed_dO_channel = *dO;
-      transformed_ddX_channel = *ddX;
+      if (ddX) {
+        transformed_ddX_channel = *ddX;
+      }
       if (ddO) {
         transformed_ddO_channel.ShareDataWith(*ddO);
       }
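This hunk and the previous one apply the same guard in two forms: `TransToChannelFirst(ctx, ddX, ...)` and the copy `*ddX` both evaluate the pointer unconditionally in the old code. A short sketch showing the assignment form is just as unsafe as a method call (stand-in type, not Paddle code):

// Hypothetical stand-in type, for illustration only.
struct Tensor { int v = 0; };

int main() {
  const Tensor* ddX = nullptr;  // caller did not supply DDInput
  Tensor transformed_ddX_channel;
  // transformed_ddX_channel = *ddX;  // old code: unary * on a null pointer
  if (ddX) {                          // the fix: dereference only when present
    transformed_ddX_channel = *ddX;
  }
  return transformed_ddX_channel.v;
}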
@@ -729,15 +735,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       transformed_X.Resize(new_input_shape);
       transformed_ddX.Resize(new_input_shape);
       transformed_dX.Resize(new_input_shape);
-      auto& dev_ctx =
-          ctx.template device_context<paddle::platform::CUDADeviceContext>();

       transformed_X =
           ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
               new_input_shape, dev_ctx);
-      transformed_ddX =
-          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
-              new_input_shape, dev_ctx);
+      if (ddX) {
+        transformed_ddX =
+            ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
+                new_input_shape, dev_ctx);
+      }
       if (dX) {
         transformed_dX =
             ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
@@ -751,16 +757,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       case 4: {
         math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
             ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-        math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-            ctx, input_pad, transformed_ddX_channel, pad_value,
-            &transformed_ddX);
+        if (ddX) {
+          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              ctx, input_pad, transformed_ddX_channel, pad_value,
+              &transformed_ddX);
+        }
       } break;
       case 5: {
         math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
             ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-        math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-            ctx, input_pad, transformed_ddX_channel, pad_value,
-            &transformed_ddX);
+        if (ddX) {
+          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              ctx, input_pad, transformed_ddX_channel, pad_value,
+              &transformed_ddX);
+        }
       } break;
       default:
         PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
@@ -768,7 +778,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     } else {
       transformed_X.ShareDataWith(transformed_X_channel);
-      transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      if (ddX) {
+        transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      }
       if (dX) {
         transformed_dX.ShareDataWith(transformed_dX_channel);
       }
@@ -936,10 +948,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
             ctx, &transformed_ddO_channel, ddO);
       }
     }
-    T* transformed_dy_channel = nullptr;
+    T* transformed_dy_channel = transformed_dO_channel.data<T>();
     if (dW && ddX) {
       ddx = transformed_ddX.data<T>();
-      transformed_dy_channel = transformed_dO_channel.data<T>();
       for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
             [&](void* workspace_ptr) {
......
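The last hunk above makes `transformed_dy_channel` point at `transformed_dO_channel`'s data unconditionally instead of only under `if (dW && ddX)`; the pointer is apparently consumed by later branches of the kernel as well, which previously read nullptr whenever ddX was absent. A hedged stand-alone illustration of that initialization bug:

#include <cstdio>

int main() {
  float dO[4] = {1.f, 2.f, 3.f, 4.f};      // stands in for transformed_dO_channel
  bool dW = true, ddX = false, dX = true;  // ddX absent, dX still requested

  // float* dy = nullptr;                  // old code: assigned only when dW && ddX
  float* dy = dO;                          // fixed code: valid on every path
  if (dW && ddX) { /* dW path consumes dy */ }
  if (dX) {
    std::printf("dy[0] = %f\n", dy[0]);    // old code could reach here with dy == nullptr
  }
  return 0;
}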
@@ -651,7 +651,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     // transform Tensor
     Tensor transformed_X(X->type());
     Tensor transformed_dY(dY->type());
-    Tensor transformed_ddX(ddX->type());
+    Tensor transformed_ddX(X->type());
     if (channel_last) {
       ResizeToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
@@ -660,13 +660,16 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
       ResizeToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
       TransToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
-      ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
-      TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
+      if (ddX) {
+        ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
+        TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
+      }
     } else {
       transformed_X = *X;
       transformed_dY = *dY;
-      transformed_ddX = *ddX;
+      if (ddX) {
+        transformed_ddX = *ddX;
+      }
     }

     // update padding and dilation
@@ -857,12 +860,11 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
           } else if (data_dim == 3U) {
             vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
           }
-          Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
-          blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
-                      T(0.0));
         }
+        Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
+        blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                    T(0.0));

         if (ddW_in) {
           Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape);
           Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
......
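The final hunk hoists `w_slice`/`blas.MatMul` out of the block that runs only for the ddX contribution, so `ddy_slice` is overwritten (beta of `T(0.0)`) on every group iteration, and the `if (ddW_in)` branch below can then accumulate onto a defined value even when ddX is null. A scalar analogue of that "always overwrite, then optionally accumulate" structure (stand-in arithmetic, not Paddle's BLAS wrapper):

#include <cstdio>

int main() {
  bool has_ddX = false, has_ddW = true;  // only the ddW contribution exists
  float W = 2.0f, x = 3.0f, ddW = 0.5f;
  float col = 0.0f;                      // im2col output; stays zero without ddX

  if (has_ddX) { col = 1.0f; }           // analogue of the guarded im2col/vol2col
  float ddy = W * col;                   // hoisted MatMul: overwrite (beta = 0)
  if (has_ddW) { ddy += x * ddW; }       // second contribution accumulates
  std::printf("ddy = %f\n", ddy);        // defined even though ddX was absent
  return 0;
}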