提交 5da03e1b 编写于 作者: B Bin Li

Fix wrong branching of eltwise

上级 c62b41fa
...@@ -864,16 +864,19 @@ class EltwiseOp : public Operation { ...@@ -864,16 +864,19 @@ class EltwiseOp : public Operation {
const T *input0_ptr = input0->data<T>(); const T *input0_ptr = input0->data<T>();
const T *input1_ptr = input1->data<T>(); const T *input1_ptr = input1->data<T>();
if (data_format_ == NCHW && input1->dim_size() > 0 && if (data_format_ == NCHW && input1->dim_size() > 0) {
input1->size() <= input0->size()) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
DstType *output_ptr = output->mutable_data<DstType>(); DstType *output_ptr = output->mutable_data<DstType>();
TensorEltwisePerChannel( if (input1->size() < input0->size()) {
type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), TensorEltwisePerChannel(
input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), type_, input0_ptr, input1_ptr, coeff_, input0->dim(0),
input0->dim(2) * input0->dim(3), swapped, output_ptr); input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1),
input0->dim(2) * input0->dim(3), swapped, output_ptr);
} else {
TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(),
swapped, output_ptr);
}
} else { } else {
const std::vector<index_t> &input0_shape = input0->shape(); const std::vector<index_t> &input0_shape = input0->shape();
std::vector<index_t> input1_shape(rank_diff, 1); std::vector<index_t> input1_shape(rank_diff, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册