未验证 提交 9347df84 编写于 作者: Z zhupengyang 提交者: GitHub

fix prelu, softmax if shape containes 0 (#33849)

上级 8225a6a1
...@@ -39,13 +39,19 @@ class PReluKernel : public framework::OpKernel<T> { ...@@ -39,13 +39,19 @@ class PReluKernel : public framework::OpKernel<T> {
int index = 0; int index = 0;
int i = 0; int i = 0;
if (mode == "channel") { if (mode == "channel") {
int temp = numel / (dim[0] * dim[1]); int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1]; index = (i / temp) % dim[1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
} }
} else if (mode == "element") { } else if (mode == "element") {
int temp = numel / dim[0]; int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
index = i % temp; index = i % temp;
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
...@@ -75,18 +81,23 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -75,18 +81,23 @@ class PReluGradKernel : public framework::OpKernel<T> {
auto dim = x->dims(); auto dim = x->dims();
int index = 0; int index = 0;
int i = 0; int i = 0;
int temp = 0;
if (dx) { if (dx) {
T* dx_ptr = dx->mutable_data<T>(context.GetPlace()); T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
if (mode == "channel") { if (mode == "channel") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1]; index = (i / temp) % dim[1];
dx_ptr[i] = dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
} }
} else if (mode == "element") { } else if (mode == "element") {
temp = numel / dim[0]; int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
index = i % temp; index = i % temp;
dx_ptr[i] = dx_ptr[i] =
...@@ -105,13 +116,19 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -105,13 +116,19 @@ class PReluGradKernel : public framework::OpKernel<T> {
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel()); memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") { if (mode == "channel") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1]; index = (i / temp) % dim[1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
} }
} else if (mode == "element") { } else if (mode == "element") {
temp = numel / dim[0]; int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
index = i % temp; index = i % temp;
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
......
...@@ -65,6 +65,9 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -65,6 +65,9 @@ class SoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
if (Out->numel() == 0) {
return;
}
const int n = SizeToAxis(axis, X->dims()); const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims()); const int d = SizeFromAxis(axis, X->dims());
...@@ -97,6 +100,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> { ...@@ -97,6 +100,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
if (dX->numel() == 0) {
return;
}
const int n = SizeToAxis(axis, dX->dims()); const int n = SizeToAxis(axis, dX->dims());
const int d = SizeFromAxis(axis, dX->dims()); const int d = SizeFromAxis(axis, dX->dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册