Commit 743dfd82 authored by wanghaoshuang

Add nullptr check

Parent 14fb15b6
...
@@ -68,8 +68,9 @@ class ClipOpGrad : public framework::OperatorWithKernel {
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx.Input<LoDTensor>("X")->dims();
     auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
-    x_grad->Resize(x_dims);
+    if (x_grad != nullptr) {
+      x_grad->Resize(x_dims);
+    }
   }
 };
...
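The guard reflects the convention this commit relies on: when no consumer needs the gradient of X, the output variable is absent and the `Output` lookup presumably returns `nullptr`, so the shape must only be propagated when the pointer is non-null. A minimal standalone C++ sketch of the same guard pattern (the `Tensor` type and `InferGradShape` helper here are illustrative stand-ins, not Paddle's real API):

```cpp
#include <cstdio>

// Illustrative stand-in, not Paddle's real type; the point is only the
// nullptr guard around an optional gradient output.
struct Tensor {
  int rows = 0;
  int cols = 0;
  void Resize(int r, int c) { rows = r; cols = c; }
};

// Mirrors the guarded InferShape above: resize the gradient only if it exists.
void InferGradShape(Tensor* x_grad, int rows, int cols) {
  if (x_grad != nullptr) {
    x_grad->Resize(rows, cols);
  }
}

int main() {
  Tensor g;
  InferGradShape(&g, 2, 3);       // normal path: shape is propagated
  InferGradShape(nullptr, 2, 3);  // guarded path: no-op instead of a crash
  std::printf("%d x %d\n", g.rows, g.cols);  // prints: 2 x 3
  return 0;
}
```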
...
@@ -43,22 +43,24 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel {
     auto min = context.Attr<float>("min");
     auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    auto* x = context.Input<LoDTensor>("X");
-    auto dims = d_x->dims();
-    int64_t count = d_out->numel();
-    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
-    auto d_out_data = d_out->data<T>();
-    auto x_data = x->data<T>();
-    int N = d_x->dims()[0];
-    int D = d_x->dims()[1];
-    int block = 512;
-    int grid = (N * D + block - 1) / block;
-    ClipGradientKernel<T><<<
-        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                            context.device_context())
-                            .stream()>>>(count, min, max, x_data, d_out_data,
-                                         d_x_data);
+    if (d_x != nullptr) {
+      auto* x = context.Input<LoDTensor>("X");
+      auto dims = d_x->dims();
+      int64_t count = d_out->numel();
+      auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      auto d_out_data = d_out->data<T>();
+      auto x_data = x->data<T>();
+      int N = d_x->dims()[0];
+      int D = d_x->dims()[1];
+      int block = 512;
+      int grid = (N * D + block - 1) / block;
+      ClipGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(count, min, max, x_data, d_out_data,
+                                           d_x_data);
+    }
   }
 };
...
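In the launch configuration above, `grid = (N * D + block - 1) / block` is ceiling division: the smallest grid size whose `grid * block` threads cover all `N * D` elements. A small self-contained C++ check of that arithmetic (the values are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division used for the grid size: the smallest g with g * block >= n.
int64_t CeilDiv(int64_t n, int64_t block) { return (n + block - 1) / block; }

int main() {
  int64_t block = 512;
  // Hypothetical tensor of shape [N, D] = [100, 30]: 3000 elements.
  int64_t n = 100 * 30;
  int64_t grid = CeilDiv(n, block);  // (3000 + 511) / 512 = 6
  assert(grid * block >= n);         // every element gets a thread
  assert((grid - 1) * block < n);    // no fully idle trailing block
  return 0;
}
```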
...
@@ -78,17 +78,19 @@ class ClipGradKernel : public framework::OpKernel {
     auto min = context.op().Attr<float>("min");
     auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    auto* x = context.Input<LoDTensor>("X");
-    auto dims = d_x->dims();
-    int64_t count = d_out->numel();
-    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
-    auto d_out_data = d_out->data<T>();
-    auto x_data = x->data<T>();
-    for (int i = 0; i < count; ++i) {
-      if (x_data[i] > min && x_data[i] < max) {
-        d_x_data[i] = d_out_data[i];
-      } else {
-        d_x_data[i] = 0;
-      }
-    }
+    if (d_x != nullptr) {
+      auto* x = context.Input<LoDTensor>("X");
+      auto dims = d_x->dims();
+      int64_t count = d_out->numel();
+      auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      auto d_out_data = d_out->data<T>();
+      auto x_data = x->data<T>();
+      for (int i = 0; i < count; ++i) {
+        if (x_data[i] > min && x_data[i] < max) {
+          d_x_data[i] = d_out_data[i];
+        } else {
+          d_x_data[i] = 0;
+        }
+      }
+    }
   }
...
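The loop encodes the gradient of clipping: `d_x[i]` receives `d_out[i]` only where `x[i]` lies strictly between `min` and `max`; elsewhere the clip is saturated, so the gradient is zero. A self-contained C++ sketch of the same rule with the nullptr concern factored out (function and variable names are illustrative, not Paddle's):

```cpp
#include <cstdio>
#include <vector>

// Standalone sketch of the CPU gradient rule in the loop above: the gradient
// passes through only where x lies strictly inside (min, max).
std::vector<float> ClipGrad(const std::vector<float>& x,
                            const std::vector<float>& d_out,
                            float min, float max) {
  std::vector<float> d_x(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    d_x[i] = (x[i] > min && x[i] < max) ? d_out[i] : 0.0f;
  }
  return d_x;
}

int main() {
  // x = -2 and x = 5 are clipped (outside [-1, 1]), so their gradient is 0.
  std::vector<float> x = {-2.0f, 0.5f, 5.0f};
  std::vector<float> d_out = {1.0f, 1.0f, 1.0f};
  auto d_x = ClipGrad(x, d_out, -1.0f, 1.0f);
  for (float v : d_x) std::printf("%g ", v);  // prints: 0 1 0
  return 0;
}
```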