未验证 提交 898792ca 编写于 作者: C cc 提交者: GitHub

Fix elemntwise_sub for the size of x and y, test=develop (#4008)

上级 d51324bf
......@@ -137,10 +137,11 @@ void ElementwiseSubCompute::Run() {
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
......@@ -158,24 +159,21 @@ void ElementwiseSubActivationCompute::Run() {
std::string act_type = param.act_type;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_sub_relu_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
if (act_type != "relu") {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_relu_broadcast(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_relu_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
if (act_type == "relu") {
lite::arm::math::elementwise_sub_relu(
x_data, y_data, out_data, x_dims.production());
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
lite::arm::math::elementwise_sub_relu(
x_data, y_data, out_data, x_dims.production());
}
}
......
......@@ -28,8 +28,13 @@ bool FusionElementwiseActivationOp::CheckShape() const {
}
bool FusionElementwiseActivationOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims());
size_t x_size = param_.X->dims().size();
size_t y_size = param_.Y->dims().size();
if (x_size >= y_size) {
param_.Out->Resize(param_.X->dims());
} else {
param_.Out->Resize(param_.Y->dims());
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册