diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index 9bca9033cbb601aa3686d0be633f5fcdea2035d7..373210b3313099eeff3bd2b1c078540d79d20df4 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -900,23 +900,21 @@ class EltwiseOp : public Operation { } } - if (need_general_broadcast) { + if (input1->size() == 1) { + TensorScalarEltwise(type_, input0_ptr, input1_ptr[0], coeff_, + input0->size(), swapped, output_ptr); + } else if (input0_shape == input1_shape) { + TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), + swapped, output_ptr); + } else if (need_general_broadcast) { TensorGeneralBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, swapped, input0_shape, input1_shape, output_shape, output_ptr); - } else if (input1->size() == input0->size()) { - TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), - swapped, output_ptr); - } else if (input1->size() < input0->size()) { - if (input1->size() > 1) { - index_t common_size = input1->size(); - index_t diff_size = input0->size() / common_size; - TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, - diff_size, common_size, swapped, output_ptr); - } else { - TensorScalarEltwise(type_, input0_ptr, input1_ptr[0], coeff_, - input0->size(), swapped, output_ptr); - } + } else { + index_t common_size = input1->size(); + index_t diff_size = input0->size() / common_size; + TensorBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, + diff_size, common_size, swapped, output_ptr); } }