diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index d8520dcb37d578c41c2536a3e66420ec72b416b2..9bca9033cbb601aa3686d0be633f5fcdea2035d7 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -864,16 +864,19 @@ class EltwiseOp : public Operation { const T *input0_ptr = input0->data(); const T *input1_ptr = input1->data(); - if (data_format_ == NCHW && input1->dim_size() > 0 && - input1->size() <= input0->size()) { + if (data_format_ == NCHW && input1->dim_size() > 0) { MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); Tensor::MappingGuard output_guard(output); DstType *output_ptr = output->mutable_data(); - TensorEltwisePerChannel( - type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), - input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), - input0->dim(2) * input0->dim(3), swapped, output_ptr); - + if (input1->size() < input0->size()) { + TensorEltwisePerChannel( + type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), + input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), + input0->dim(2) * input0->dim(3), swapped, output_ptr); + } else { + TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), + swapped, output_ptr); + } } else { const std::vector &input0_shape = input0->shape(); std::vector input1_shape(rank_diff, 1);