Unverified commit c92348c3, authored by lvmengsi, committed by GitHub

fix conv_grad_grad (#20054)

Parent: 4e99c2af
@@ -355,6 +355,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     size_t workspace_size = 0;
     if (ddO) {
       ddy = ddO->mutable_data<T>(ctx.GetPlace());
+      if (ddX) {
       args1.handle = handle;
       args1.idesc.set(*ddX, iwo_group);
       args1.wdesc.set(*W, layout, iwo_group);
@@ -364,6 +365,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
       fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
       workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
+      }
     }
     if (ddW) {
       ddw = ddW->data<T>();
@@ -380,7 +382,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       }
     }
-    if (dW) {
+    if (dW && ddX) {
       dw = dW->mutable_data<T>(ctx.GetPlace());
       args3.handle = handle;
       args3.idesc.set(*ddX, iwo_group);
@@ -423,18 +425,21 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto wkspace_handle = dev_ctx.cudnn_workspace_handle();
     if (ddO) {
+      if (ddX) {
       ddx = ddX->data<T>();
       for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
             [&](void* workspace_ptr) {
               CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-                  handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in,
-                  args1.wdesc.desc(), w + i * group_offset_filter,
-                  args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size,
-                  &beta, args1.odesc.desc(), ddy + i * group_offset_out));
+                  handle, &alpha, args1.idesc.desc(),
+                  ddx + i * group_offset_in, args1.wdesc.desc(),
+                  w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1,
+                  workspace_ptr, workspace_size, &beta, args1.odesc.desc(),
+                  ddy + i * group_offset_out));
             },
             workspace_size);
       }
+      }
       if (ddW) {
         for (int i = 0; i < groups; i++) {
           wkspace_handle.RunFunc(
@@ -451,7 +456,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       }
     }
-    if (dW) {
+    if (dW && ddX) {
       ddx = ddX->data<T>();
       for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
...
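The kernel changes above all apply the same guard: a double-grad term is computed only when the optional tensor it reads actually exists, since ddX and ddW are both optional inputs of the double-grad op. A minimal standalone sketch of that control flow (plain C++ with printf stand-ins, not Paddle's kernel API; every name below is illustrative):

#include <cstdio>

struct Tensor { float value; };  // stand-in for a nullable framework tensor

// ddY = conv(ddX, W) + conv(X, ddW); dW needs ddX, dX needs ddW.
// Before the fix, the ddX-dependent paths ran even when ddX was null.
void ConvDoubleGrad(const Tensor* ddX, const Tensor* ddW,
                    Tensor* ddY, Tensor* dW, Tensor* dX) {
  if (ddY) {
    if (ddX) std::printf("ddY += conv(ddX, W)\n");
    if (ddW) std::printf("ddY += conv(X, ddW)\n");
  }
  if (dW && ddX) std::printf("dW = conv_bwd_filter(ddX, dY)\n");  // was: if (dW)
  if (dX && ddW) std::printf("dX = conv_bwd_data(dY, ddW)\n");
}

int main() {
  Tensor ddw{0.5f}, ddy{}, dw{}, dx{};
  // Only the filter perturbation ddW is fed; ddX is null, so the
  // ddX-dependent terms must be skipped rather than dereferenced.
  ConvDoubleGrad(/*ddX=*/nullptr, &ddw, &ddy, &dw, &dx);
}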
@@ -553,9 +553,10 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
     auto ddw = OutputGrad(framework::GradVarName("Filter"));
     std::vector<std::string> empty_str = {};
-    op->SetOutput(
-        "DDOutput",
-        ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DDOutput",
+                  (ddx.empty() && ddw.empty())
+                      ? empty_str
+                      : InputGrad(framework::GradVarName("Output")));
     op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
     op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
@@ -587,9 +588,10 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
     auto ddw = OutputGrad(framework::GradVarName("Filter"));
     std::vector<std::string> empty_str = {};
-    op->SetOutput(
-        "DDOutput",
-        ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DDOutput",
+                  (ddx.empty() && ddw.empty())
+                      ? empty_str
+                      : InputGrad(framework::GradVarName("Output")));
     op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
     op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
@@ -604,7 +606,8 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
   auto w_dims = ctx->GetInputDim("Filter");
   auto do_dims = ctx->GetInputDim("DOutput");
-  if (ctx->HasOutput("DDOutput") && ctx->HasInput("DDInput")) {
+  if (ctx->HasOutput("DDOutput") &&
+      (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
     ctx->SetOutputDim("DDOutput", do_dims);
   }
   if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
...
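The grad-maker and InferShape changes encode which outputs are derivable from which optional inputs: DDOutput is produced when either DDInput or DDFilter is fed (either term alone yields a ddY), while DFilter needs DDInput and, by the same symmetry, DInput needs DDFilter. A standalone sketch of that decision table (the harness below is illustrative, not Paddle's InferShapeContext):

#include <cstdio>

// Which double-grad outputs get a shape, given which inputs exist.
// Mirrors the corrected ConvOpDoubleGrad::InferShape conditions.
void InferDoubleGradShapes(bool has_ddx, bool has_ddw) {
  if (has_ddx || has_ddw)  // DDOutput = conv(ddX, W) + conv(X, ddW)
    std::printf("DDOutput <- dims of DOutput\n");
  if (has_ddx)             // DFilter = conv_bwd_filter(ddX, dY)
    std::printf("DFilter  <- dims of Filter\n");
  if (has_ddw)             // DInput = conv_bwd_data(dY, ddW)
    std::printf("DInput   <- dims of Input\n");
}

int main() {
  InferDoubleGradShapes(/*has_ddx=*/false, /*has_ddw=*/true);
}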
@@ -506,7 +506,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     // dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
     // oH, oW)
     // dw convolution double grad: im2col(vol2col) + gemm
-    if (dW) {
+    if (dW && ddX) {
       dW->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, dW, static_cast<T>(0));
       Tensor dW_arr = *dW;
@@ -549,11 +549,11 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
       math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
       math::Vol2ColFunctor<DeviceContext, T> vol2col;
       for (int i = 0; i < batch_size; ++i) {
-        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
-        Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
         Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
         for (int g = 0; g < groups; ++g) {
-          Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
+          Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
+          if (ddX) {
+            Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
            Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
            if (!is_expand) {
              col.ShareDataWith(ddx_slice);
@@ -571,14 +571,16 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
            }
            // gemm
-            Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
            Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
            blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
                        T(0.0));
+          }
          if (ddW_in) {
            Tensor ddW;
            ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
+            Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
+            Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
            if (!is_expand) {
              col.ShareDataWith(x_slice);
...
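The GemmConvDoubleGradKernel hunks move each slice inside the branch that actually uses it, so ddY accumulates only the terms whose inputs exist. The identity being computed can be checked numerically on a 1x1 convolution, where conv degenerates to a scalar multiply (a self-contained sketch, not the kernel itself):

#include <cstdio>
#include <vector>

// ddY = conv(ddX, W) + conv(X, ddW); with a 1x1 kernel both convs
// reduce to scalar multiplies, which makes the identity easy to check.
int main() {
  std::vector<float> x   = {1.0f, 2.0f, 3.0f};
  std::vector<float> ddx = {0.1f, 0.2f, 0.3f};
  float w = 2.0f, ddw = 0.5f;
  bool has_ddx = true, has_ddw = true;  // toggle to mimic absent inputs

  for (size_t i = 0; i < x.size(); ++i) {
    float ddy = 0.0f;
    if (has_ddx) ddy += ddx[i] * w;  // ddX-dependent term (guarded by the fix)
    if (has_ddw) ddy += x[i] * ddw;  // ddW-dependent term
    std::printf("ddy[%zu] = %.2f\n", i, ddy);
  }
  return 0;
}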