From 743dfd82e7e3d4891d46c96d924c6b28f4d2c08b Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 20 Sep 2017 00:01:14 +0800 Subject: [PATCH] Add nullptr check --- paddle/operators/clip_op.cc | 5 +++-- paddle/operators/clip_op.cu | 32 +++++++++++++++++--------------- paddle/operators/clip_op.h | 24 +++++++++++++----------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index 8d576bae2d2..d3d8cf176de 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -68,8 +68,9 @@ class ClipOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); - - x_grad->Resize(x_dims); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } } }; diff --git a/paddle/operators/clip_op.cu b/paddle/operators/clip_op.cu index ac6a062f6d9..7e9c6c23c2c 100644 --- a/paddle/operators/clip_op.cu +++ b/paddle/operators/clip_op.cu @@ -43,22 +43,24 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel { auto min = context.Attr("min"); auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); - auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); + if (d_x != nullptr) { + auto* x = context.Input("X"); + auto dims = d_x->dims(); + int64_t count = d_out->numel(); + auto d_x_data = d_x->mutable_data(context.GetPlace()); + auto d_out_data = d_out->data(); + auto x_data = x->data(); - int N = d_x->dims()[0]; - int D = d_x->dims()[1]; - int block = 512; - int grid = (N * D + block - 1) / block; - ClipGradientKernel<<< - grid, block, 0, reinterpret_cast( - context.device_context()) - .stream()>>>(count, min, max, x_data, d_out_data, - d_x_data); + int N = d_x->dims()[0]; + int D = d_x->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + ClipGradientKernel<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(count, min, max, x_data, d_out_data, + d_x_data); + } } }; diff --git a/paddle/operators/clip_op.h b/paddle/operators/clip_op.h index ba0aa7416f6..47bfe1b7f8e 100644 --- a/paddle/operators/clip_op.h +++ b/paddle/operators/clip_op.h @@ -78,17 +78,19 @@ class ClipGradKernel : public framework::OpKernel { auto min = context.op().Attr("min"); auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); - auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); - for (int i = 0; i < count; ++i) { - if (x_data[i] > min && x_data[i] < max) { - d_x_data[i] = d_out_data[i]; - } else { - d_x_data[i] = 0; + if (d_x != nullptr) { + auto* x = context.Input("X"); + auto dims = d_x->dims(); + int64_t count = d_out->numel(); + auto d_x_data = d_x->mutable_data(context.GetPlace()); + auto d_out_data = d_out->data(); + auto x_data = x->data(); + for (int i = 0; i < count; ++i) { + if (x_data[i] > min && x_data[i] < max) { + d_x_data[i] = d_out_data[i]; + } else { + d_x_data[i] = 0; + } } } } -- GitLab