未验证 提交 898792ca 编写于 作者: C cc 提交者: GitHub

Fix elemntwise_sub for the size of x and y, test=develop (#4008)

上级 d51324bf
...@@ -137,10 +137,11 @@ void ElementwiseSubCompute::Run() { ...@@ -137,10 +137,11 @@ void ElementwiseSubCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (x_dims.size() < y_dims.size()) { if (x_dims.size() < y_dims.size() &&
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size"; is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
} lite::arm::math::elementwise_sub_broadcast(
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast( lite::arm::math::elementwise_sub_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
...@@ -158,24 +159,21 @@ void ElementwiseSubActivationCompute::Run() { ...@@ -158,24 +159,21 @@ void ElementwiseSubActivationCompute::Run() {
std::string act_type = param.act_type; std::string act_type = param.act_type;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type != "relu") {
lite::arm::math::elementwise_sub_relu_broadcast( LOG(FATAL) << "unsupported Activation type: " << act_type;
x_data, y_data, out_data, pre, n, post); }
} else { if (x_dims.size() < y_dims.size() &&
LOG(FATAL) << "unsupported Activation type: " << act_type; is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
} lite::arm::math::elementwise_sub_relu_broadcast(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_relu_broadcast(
x_data, y_data, out_data, pre, n, post);
} else { } else {
if (act_type == "relu") { lite::arm::math::elementwise_sub_relu(
lite::arm::math::elementwise_sub_relu( x_data, y_data, out_data, x_dims.production());
x_data, y_data, out_data, x_dims.production());
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} }
} }
......
...@@ -28,8 +28,13 @@ bool FusionElementwiseActivationOp::CheckShape() const { ...@@ -28,8 +28,13 @@ bool FusionElementwiseActivationOp::CheckShape() const {
} }
bool FusionElementwiseActivationOp::InferShapeImpl() const { bool FusionElementwiseActivationOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); size_t x_size = param_.X->dims().size();
param_.Out->Resize(param_.X->dims()); size_t y_size = param_.Y->dims().size();
if (x_size >= y_size) {
param_.Out->Resize(param_.X->dims());
} else {
param_.Out->Resize(param_.Y->dims());
}
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册