Unverified commit c92348c3, authored by lvmengsi, committed by GitHub

fix conv_grad_grad (#20054)

Parent 4e99c2af
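
The hunks below let the conv double-grad kernels run when only one of the two second-order inputs (ddX, ddW) is fed. For reference, the second-order relations of a convolution Y = X * W that motivate the guards (implied by the patch, not spelled out in it; corr denotes correlation over the batch, conv^T the transposed convolution):

\[
\begin{aligned}
ddY &= ddX * W + X * ddW && \text{(defined when ddX or ddW is present)} \\
dW  &= \operatorname{corr}(ddX,\, dY) && \text{(needs ddX)} \\
dX  &= \operatorname{conv}^{\top}(dY,\, ddW) && \text{(needs ddW)}
\end{aligned}
\]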
......@@ -355,6 +355,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
size_t workspace_size = 0;
if (ddO) {
ddy = ddO->mutable_data<T>(ctx.GetPlace());
if (ddX) {
args1.handle = handle;
args1.idesc.set(*ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
......@@ -364,6 +365,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
}
if (ddW) {
ddw = ddW->data<T>();
......@@ -380,7 +382,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
}
}
if (dW) {
if (dW && ddX) {
dw = dW->mutable_data<T>(ctx.GetPlace());
args3.handle = handle;
args3.idesc.set(*ddX, iwo_group);
......@@ -423,18 +425,21 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto wkspace_handle = dev_ctx.cudnn_workspace_handle();
if (ddO) {
if (ddX) {
ddx = ddX->data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in,
args1.wdesc.desc(), w + i * group_offset_filter,
args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size,
&beta, args1.odesc.desc(), ddy + i * group_offset_out));
handle, &alpha, args1.idesc.desc(),
ddx + i * group_offset_in, args1.wdesc.desc(),
w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1,
workspace_ptr, workspace_size, &beta, args1.odesc.desc(),
ddy + i * group_offset_out));
},
workspace_size);
}
}
if (ddW) {
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
......@@ -451,7 +456,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
}
}
if (dW) {
if (dW && ddX) {
ddx = ddX->data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
......
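Condensed from the cuDNN hunks above, the control flow that the patch establishes looks roughly like the sketch below. This is an illustrative stand-alone snippet, not PaddlePaddle code: Tensor is a placeholder type, and the three helpers are hypothetical stand-ins for the forward, filter-gradient, and data-gradient convolutions.

#include <vector>

using Tensor = std::vector<float>;  // placeholder for framework::Tensor

// Hypothetical helpers (assumed here, not Paddle APIs).
void AddConvForward(const Tensor& in, const Tensor& filter, Tensor* out);  // *out += conv(in, filter)
void ConvFilterGrad(const Tensor& in, const Tensor& d_out, Tensor* d_filter);
void ConvInputGrad(const Tensor& d_out, const Tensor& filter, Tensor* d_in);

// ddX and ddW may each be null; every read of them is guarded, which is the
// point of this fix.
void ConvDoubleGradSketch(const Tensor* X, const Tensor* W, const Tensor* dY,
                          const Tensor* ddX, const Tensor* ddW,
                          Tensor* ddY, Tensor* dW, Tensor* dX) {
  if (ddY) {
    if (ddX) AddConvForward(*ddX, *W, ddY);  // ddY += conv(ddX, W)
    if (ddW) AddConvForward(*X, *ddW, ddY);  // ddY += conv(X, ddW)
  }
  if (dW && ddX) ConvFilterGrad(*ddX, *dY, dW);  // dW depends only on ddX
  if (dX && ddW) ConvInputGrad(*dY, *ddW, dX);   // dX depends only on ddW
}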
......@@ -553,9 +553,10 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
auto ddw = OutputGrad(framework::GradVarName("Filter"));
std::vector<std::string> empty_str = {};
op->SetOutput(
"DDOutput",
ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
op->SetOutput("DDOutput",
(ddx.empty() && ddw.empty())
? empty_str
: InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
......@@ -587,9 +588,10 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
auto ddw = OutputGrad(framework::GradVarName("Filter"));
std::vector<std::string> empty_str = {};
op->SetOutput(
"DDOutput",
ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
op->SetOutput("DDOutput",
(ddx.empty() && ddw.empty())
? empty_str
: InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
......@@ -604,7 +606,8 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("Filter");
auto do_dims = ctx->GetInputDim("DOutput");
if (ctx->HasOutput("DDOutput") && ctx->HasInput("DDInput")) {
if (ctx->HasOutput("DDOutput") &&
(ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
ctx->SetOutputDim("DDOutput", do_dims);
}
if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
......
......@@ -506,7 +506,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
// dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
// oH, oW)
// dw convolution double grad: im2col(vol2col) + gemm
if (dW) {
if (dW && ddX) {
dW->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, dW, static_cast<T>(0));
Tensor dW_arr = *dW;
......@@ -549,11 +549,11 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; ++g) {
Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
if (ddX) {
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
......@@ -571,14 +571,16 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
}
// gemm
Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
T(0.0));
}
if (ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(x_slice);
......
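The comment in the hunk above ("dw = ddx * dy ... im2col(vol2col) + gemm") compresses the shape reasoning: dW(Cout, Cin, kh, kw) is accumulated from ddX(N, Cin, H, W) and dY(N, Cout, oH, oW). A self-contained sketch of that computation with plain loops, assuming stride 1, no padding, and a single group (names and layouts here are illustrative, not the Paddle implementation):

#include <cstddef>
#include <vector>

// im2col for one image: (Cin, H, W) -> (Cin*kh*kw, oH*oW), stride 1, no padding.
std::vector<float> Im2Col(const std::vector<float>& x, int Cin, int H, int W,
                          int kh, int kw, int oH, int oW) {
  std::vector<float> col(static_cast<std::size_t>(Cin) * kh * kw * oH * oW);
  for (int c = 0; c < Cin; ++c)
    for (int i = 0; i < kh; ++i)
      for (int j = 0; j < kw; ++j)
        for (int oh = 0; oh < oH; ++oh)
          for (int ow = 0; ow < oW; ++ow)
            col[(((c * kh + i) * kw + j) * oH + oh) * oW + ow] =
                x[(c * H + (oh + i)) * W + (ow + j)];
  return col;
}

// dW(Cout, Cin*kh*kw) += dY_n(Cout, oH*oW) * col_n(Cin*kh*kw, oH*oW)^T,
// summed over the batch; this is the "dw = ddx * dy" of the comment above.
void AccumulateFilterGrad(const std::vector<float>& ddx,  // N*Cin*H*W
                          const std::vector<float>& dy,   // N*Cout*oH*oW
                          std::vector<float>* dW,         // Cout*Cin*kh*kw
                          int N, int Cin, int Cout, int H, int W,
                          int kh, int kw) {
  const int oH = H - kh + 1, oW = W - kw + 1;
  const int K = Cin * kh * kw, M = oH * oW;
  for (int n = 0; n < N; ++n) {
    std::vector<float> x(ddx.begin() + static_cast<std::ptrdiff_t>(n) * Cin * H * W,
                         ddx.begin() + static_cast<std::ptrdiff_t>(n + 1) * Cin * H * W);
    std::vector<float> col = Im2Col(x, Cin, H, W, kh, kw, oH, oW);
    for (int co = 0; co < Cout; ++co)
      for (int k = 0; k < K; ++k)
        for (int m = 0; m < M; ++m)
          (*dW)[co * K + k] +=
              dy[(static_cast<std::size_t>(n) * Cout + co) * M + m] *
              col[static_cast<std::size_t>(k) * M + m];
  }
}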