From 9347df845fb91a8daef3af71e9ab91a98188a146 Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Wed, 30 Jun 2021 14:44:00 +0800
Subject: [PATCH] fix prelu, softmax if shape contains 0 (#33849)

---
 paddle/fluid/operators/prelu_op.h   | 31 ++++++++++++++++++++++-------
 paddle/fluid/operators/softmax_op.h |  6 ++++++
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h
index cfc0a2b6fb..60fd75ce3c 100644
--- a/paddle/fluid/operators/prelu_op.h
+++ b/paddle/fluid/operators/prelu_op.h
@@ -39,13 +39,19 @@ class PReluKernel : public framework::OpKernel<T> {
     int index = 0;
     int i = 0;
     if (mode == "channel") {
-      int temp = numel / (dim[0] * dim[1]);
+      int temp = 1;
+      for (int j = 2; j < dim.size(); j++) {
+        temp *= dim[j];
+      }
       for (i = 0; i < numel; i++) {
         index = (i / temp) % dim[1];
         o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
       }
     } else if (mode == "element") {
-      int temp = numel / dim[0];
+      int temp = 1;
+      for (int j = 1; j < dim.size(); j++) {
+        temp *= dim[j];
+      }
       for (i = 0; i < numel; i++) {
         index = i % temp;
         o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
@@ -75,18 +81,23 @@ class PReluGradKernel : public framework::OpKernel<T> {
     auto dim = x->dims();
     int index = 0;
     int i = 0;
-    int temp = 0;
     if (dx) {
       T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
       if (mode == "channel") {
+        int temp = 1;
+        for (int j = 2; j < dim.size(); j++) {
+          temp *= dim[j];
+        }
         for (i = 0; i < numel; i++) {
-          temp = numel / (dim[0] * dim[1]);
           index = (i / temp) % dim[1];
           dx_ptr[i] =
               x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
         }
       } else if (mode == "element") {
-        temp = numel / dim[0];
+        int temp = 1;
+        for (int j = 1; j < dim.size(); j++) {
+          temp *= dim[j];
+        }
         for (i = 0; i < numel; i++) {
           index = i % temp;
           dx_ptr[i] =
@@ -105,13 +116,19 @@ class PReluGradKernel : public framework::OpKernel<T> {
       memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
 
       if (mode == "channel") {
+        int temp = 1;
+        for (int j = 2; j < dim.size(); j++) {
+          temp *= dim[j];
+        }
         for (i = 0; i < numel; i++) {
-          temp = numel / (dim[0] * dim[1]);
           index = (i / temp) % dim[1];
           dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
         }
       } else if (mode == "element") {
-        temp = numel / dim[0];
+        int temp = 1;
+        for (int j = 1; j < dim.size(); j++) {
+          temp *= dim[j];
+        }
         for (i = 0; i < numel; i++) {
           index = i % temp;
           dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index 08266318fb..68a1649d0a 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -65,6 +65,9 @@ class SoftmaxKernel : public framework::OpKernel<T> {
 
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
+    if (Out->numel() == 0) {
+      return;
+    }
 
     const int n = SizeToAxis(axis, X->dims());
     const int d = SizeFromAxis(axis, X->dims());
@@ -97,6 +100,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
 
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
+    if (dX->numel() == 0) {
+      return;
+    }
 
     const int n = SizeToAxis(axis, dX->dims());
     const int d = SizeFromAxis(axis, dX->dims());
-- 
GitLab