diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index 8d576bae2d28ac870c058a43b171ea3ba35cb297..d3d8cf176de67bfbd4f8c95b780f6f0284903da7 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -68,8 +68,9 @@ class ClipOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); - - x_grad->Resize(x_dims); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } } }; diff --git a/paddle/operators/clip_op.cu b/paddle/operators/clip_op.cu index ac6a062f6d9fc84b43d1fe74c0d4e922f16be494..7e9c6c23c2c46336a115ec36cdcba003d4667c24 100644 --- a/paddle/operators/clip_op.cu +++ b/paddle/operators/clip_op.cu @@ -43,22 +43,24 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel { auto min = context.Attr("min"); auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); - auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); + if (d_x != nullptr) { + auto* x = context.Input("X"); + auto dims = d_x->dims(); + int64_t count = d_out->numel(); + auto d_x_data = d_x->mutable_data(context.GetPlace()); + auto d_out_data = d_out->data(); + auto x_data = x->data(); - int N = d_x->dims()[0]; - int D = d_x->dims()[1]; - int block = 512; - int grid = (N * D + block - 1) / block; - ClipGradientKernel<<< - grid, block, 0, reinterpret_cast( - context.device_context()) - .stream()>>>(count, min, max, x_data, d_out_data, - d_x_data); + int N = d_x->dims()[0]; + int D = d_x->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + ClipGradientKernel<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(count, min, max, x_data, d_out_data, + d_x_data); + } } }; diff --git a/paddle/operators/clip_op.h b/paddle/operators/clip_op.h index ba0aa7416f6c2c354d4fee6cae12c318db31395b..47bfe1b7f8ef84e4896b6a5fef839d582adc9599 100644 --- a/paddle/operators/clip_op.h +++ b/paddle/operators/clip_op.h @@ -78,17 +78,19 @@ class ClipGradKernel : public framework::OpKernel { auto min = context.op().Attr("min"); auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); - auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); - for (int i = 0; i < count; ++i) { - if (x_data[i] > min && x_data[i] < max) { - d_x_data[i] = d_out_data[i]; - } else { - d_x_data[i] = 0; + if (d_x != nullptr) { + auto* x = context.Input("X"); + auto dims = d_x->dims(); + int64_t count = d_out->numel(); + auto d_x_data = d_x->mutable_data(context.GetPlace()); + auto d_out_data = d_out->data(); + auto x_data = x->data(); + for (int i = 0; i < count; ++i) { + if (x_data[i] > min && x_data[i] < max) { + d_x_data[i] = d_out_data[i]; + } else { + d_x_data[i] = 0; + } } } }