diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index 768349fffd268d58aee8f05260f37c841021947e..d345427a2a3a6258d90f3a55c71f1b8d8004419b 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -1171,12 +1171,23 @@ class EltwiseOp : public Operation { for (int i = 0; i < input_size; ++i) { if (ws->HasTensor(operator_def_->input(i)) && ws->GetTensor(operator_def_->input(i))->is_weight()) { - MACE_CHECK(TransformFilter( - context, - operator_def_.get(), - i, - OpenCLBufferType::ARGUMENT, - mem_type) == MaceStatus::MACE_SUCCESS); + if (ws->GetTensor(operator_def_->input(i))->dim_size() == 1) { + MACE_CHECK(TransformFilter( + context, + operator_def_.get(), + i, + OpenCLBufferType::ARGUMENT, + mem_type) == MaceStatus::MACE_SUCCESS); + } else if (ws->GetTensor(operator_def_->input(i))->dim_size() == 4) { + MACE_CHECK(TransformFilter( + context, + operator_def_.get(), + i, + OpenCLBufferType::IN_OUT_CHANNEL, + mem_type) == MaceStatus::MACE_SUCCESS); + } else { + MACE_NOT_IMPLEMENTED; + } } } }