提交 5da03e1b 编写于 作者: B Bin Li

Fix wrong branching of eltwise

上级 c62b41fa
......@@ -864,16 +864,19 @@ class EltwiseOp : public Operation {
const T *input0_ptr = input0->data<T>();
const T *input1_ptr = input1->data<T>();
if (data_format_ == NCHW && input1->dim_size() > 0 &&
input1->size() <= input0->size()) {
if (data_format_ == NCHW && input1->dim_size() > 0) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
Tensor::MappingGuard output_guard(output);
DstType *output_ptr = output->mutable_data<DstType>();
TensorEltwisePerChannel(
type_, input0_ptr, input1_ptr, coeff_, input0->dim(0),
input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1),
input0->dim(2) * input0->dim(3), swapped, output_ptr);
if (input1->size() < input0->size()) {
TensorEltwisePerChannel(
type_, input0_ptr, input1_ptr, coeff_, input0->dim(0),
input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1),
input0->dim(2) * input0->dim(3), swapped, output_ptr);
} else {
TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(),
swapped, output_ptr);
}
} else {
const std::vector<index_t> &input0_shape = input0->shape();
std::vector<index_t> input1_shape(rank_diff, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册