未验证 提交 23845893 编写于 作者: L lvmengsi 提交者: GitHub

Fix conv_grad_grad (#20469)

* fix_conv_grad_grad

* fix_bug, test=develop
上级 82992033
......@@ -609,6 +609,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto dX = ctx.Output<Tensor>("DInput");
if (ddO) {
ddO->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, ddO, static_cast<T>(0));
}
if (dW) {
dW->mutable_data<T>(ctx.GetPlace());
......@@ -646,7 +648,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
// transform Tensors to channel first-----------
Tensor transformed_X_channel(X->type());
Tensor transformed_dO_channel(dO->type());
Tensor transformed_ddX_channel(ddX->type());
Tensor transformed_ddX_channel(X->type());
Tensor transformed_ddO_channel(dO->type());
Tensor transformed_dX_channel(X->type());
......@@ -662,10 +664,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, dO, &transformed_dO_channel);
if (ddX) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, ddX, &transformed_ddX_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>(
ctx, ddX, &transformed_ddX_channel);
}
if (ddO) {
ResizeToChannelFirst<platform::CUDADeviceContext, T>(
......@@ -680,7 +684,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
} else {
transformed_X_channel = *X;
transformed_dO_channel = *dO;
if (ddX) {
transformed_ddX_channel = *ddX;
}
if (ddO) {
transformed_ddO_channel.ShareDataWith(*ddO);
}
......@@ -729,15 +735,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
transformed_X.Resize(new_input_shape);
transformed_ddX.Resize(new_input_shape);
transformed_dX.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_X =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
if (ddX) {
transformed_ddX =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
}
if (dX) {
transformed_dX =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
......@@ -751,16 +757,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_ddX_channel, pad_value,
&transformed_ddX);
}
} break;
default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
......@@ -768,7 +778,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
} else {
transformed_X.ShareDataWith(transformed_X_channel);
if (ddX) {
transformed_ddX.ShareDataWith(transformed_ddX_channel);
}
if (dX) {
transformed_dX.ShareDataWith(transformed_dX_channel);
}
......@@ -936,10 +948,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
ctx, &transformed_ddO_channel, ddO);
}
}
T* transformed_dy_channel = nullptr;
T* transformed_dy_channel = transformed_dO_channel.data<T>();
if (dW && ddX) {
ddx = transformed_ddX.data<T>();
transformed_dy_channel = transformed_dO_channel.data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
......
......@@ -651,7 +651,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
// transform Tensor
Tensor transformed_X(X->type());
Tensor transformed_dY(dY->type());
Tensor transformed_ddX(ddX->type());
Tensor transformed_ddX(X->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
......@@ -660,14 +660,17 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
ResizeToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
TransToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
if (ddX) {
ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
}
} else {
transformed_X = *X;
transformed_dY = *dY;
if (ddX) {
transformed_ddX = *ddX;
}
}
// update padding and dilation
auto in_dims = transformed_X.dims();
......@@ -857,11 +860,10 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
}
Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
T(0.0));
}
if (ddW_in) {
Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册