diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index cfc0a2b6fb1128ee4460cbc669772c6257aad8ab..60fd75ce3cffd3e0565945b281ad4c4961385956 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -39,13 +39,19 @@ class PReluKernel : public framework::OpKernel { int index = 0; int i = 0; if (mode == "channel") { - int temp = numel / (dim[0] * dim[1]); + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { index = (i / temp) % dim[1]; o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; } } else if (mode == "element") { - int temp = numel / dim[0]; + int temp = 1; + for (int j = 1; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { index = i % temp; o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; @@ -75,18 +81,23 @@ class PReluGradKernel : public framework::OpKernel { auto dim = x->dims(); int index = 0; int i = 0; - int temp = 0; if (dx) { T* dx_ptr = dx->mutable_data(context.GetPlace()); if (mode == "channel") { + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { - temp = numel / (dim[0] * dim[1]); index = (i / temp) % dim[1]; dx_ptr[i] = x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; } } else if (mode == "element") { - temp = numel / dim[0]; + int temp = 1; + for (int j = 1; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { index = i % temp; dx_ptr[i] = @@ -105,13 +116,19 @@ class PReluGradKernel : public framework::OpKernel { memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel()); if (mode == "channel") { + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { - temp = numel / (dim[0] * dim[1]); index = (i / temp) % dim[1]; dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; } } else if (mode == "element") { - temp = numel / dim[0]; + int temp = 1; + for (int j = 1; j < dim.size(); j++) { + temp *= dim[j]; + } for (i = 0; i < numel; i++) { index = i % temp; dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 08266318fb970ba976269991351152c22b38dbf2..68a1649d0a039d8b63b4811f1e7606b0c071fb9d 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -65,6 +65,9 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); + if (Out->numel() == 0) { + return; + } const int n = SizeToAxis(axis, X->dims()); const int d = SizeFromAxis(axis, X->dims()); @@ -97,6 +100,9 @@ class SoftmaxGradKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); + if (dX->numel() == 0) { + return; + } const int n = SizeToAxis(axis, dX->dims()); const int d = SizeFromAxis(axis, dX->dims());