未验证 提交 a8e13aae 编写于 作者: M mapingshuo 提交者: GitHub

fix segment fault bug, test=develop (#3282)

上级 e70eade1
......@@ -302,10 +302,10 @@ void elementwise_add_grad_broadcast<float>(const float* dout_grad,
int pre,
int n,
int post) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
if (y_grad != nullptr) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
......@@ -582,10 +582,10 @@ void elementwise_sub_grad<float>(const float* dout_grad,
float* x_grad,
float* y_grad,
int num) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, num);
}
if (y_grad) {
if (y_grad != nullptr) {
int cnt = num >> 4;
int remain = num & 0x0f;
float32x4_t minus = vdupq_n_f32(-1);
......@@ -624,10 +624,10 @@ void elementwise_sub_grad_broadcast<float>(const float* dout_grad,
int pre,
int n,
int post) {
if (x_grad) {
if (x_grad != nullptr) {
elementwise_add_grad(dout_grad, x_grad, pre * n * post);
}
if (y_grad) {
if (y_grad != nullptr) {
memset(y_grad, 0, n * sizeof(float));
#pragma omp parallel for
for (int i = 0; i < pre; ++i) {
......
......@@ -76,8 +76,8 @@ void ElementwiseAddGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_grad_data = param.OutGrad->data<float>();
float* x_grad_data;
float* y_grad_data;
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
......@@ -122,8 +122,8 @@ void ElementwiseSubGradCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
const float* out_data = param.OutGrad->data<float>();
float* x_grad_data;
float* y_grad_data;
float* x_grad_data = nullptr;
float* y_grad_data = nullptr;
if (param.XGrad) {
x_grad_data = param.XGrad->mutable_data<float>();
}
......@@ -137,10 +137,16 @@ void ElementwiseSubGradCompute::Run() {
if (!param.XGrad || !param.YGrad) {
CHECK(param.XGrad || param.YGrad);
if (param.XGrad) {
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, x_dims.production());
return;
} else {
lite::arm::math::elementwise_sub_grad(
out_data, x_grad_data, y_grad_data, y_dims.production());
return;
}
}
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise sub grad don't support x_dims size < y_dims size";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册