diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index cfd0bdba91d2a5efe91ea5b407cc461bbd1aaeef..94622ac3f16625837f3336e90ba1d663982ab33a 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -900,23 +900,21 @@ class EltwiseOp : public Operation { } } - if (need_general_broadcast) { + if (input1->size() == 1) { + TensorScalarEltwise(type_, input0_ptr, input1_ptr[0], coeff_, + input0->size(), swapped, output_ptr); + } else if (input0_shape == input1_shape) { + TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), + swapped, output_ptr); + } else if (need_general_broadcast) { TensorGeneralBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, swapped, input0_shape, input1_shape, output_shape, output_ptr); - } else if (input1->size() == input0->size()) { - TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), - swapped, output_ptr); - } else if (input1->size() < input0->size()) { - if (input1->size() > 1) { - index_t common_size = input1->size(); - index_t diff_size = input0->size() / common_size; - TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, - diff_size, common_size, swapped, output_ptr); - } else { - TensorScalarEltwise(type_, input0_ptr, input1_ptr[0], coeff_, - input0->size(), swapped, output_ptr); - } + } else { + index_t common_size = input1->size(); + index_t diff_size = input0->size() / common_size; + TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, + diff_size, common_size, swapped, output_ptr); } }