diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 8609f178886808f5dedf2de86e7cf7941c4a4c5d..32b80518e505c9dbc46d392308cf572a4e7f1278 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -856,7 +856,7 @@ struct UnsqueezeParam { lite::Tensor* XShape{}; std::vector axes{}; const lite::Tensor* axes_tensor{}; - std::vector* axes_tensor_vct{}; + std::vector axes_tensor_vct{}; }; /// ----------------------- expand operators ---------------------- diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 8db14d0660a7b48b94406e35908f0636a53d57f6..39b275b7b55f79f2c8daf16ab0a6acd2e76e8b48 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -66,10 +66,7 @@ bool UnsqueezeOp::InferShape() const { std::vector final_axes; auto axes = param_.axes; auto *axes_tensor = param_.axes_tensor; - std::vector axes_tensor_vct; - if (param_.axes_tensor_vct) { - axes_tensor_vct = *(param_.axes_tensor_vct); - } + auto axes_tensor_vct = param_.axes_tensor_vct; if (!axes.empty()) { final_axes = axes; @@ -79,7 +76,7 @@ bool UnsqueezeOp::InferShape() const { axes_tensor_data + axes_tensor->numel()); } else if (!axes_tensor_vct.empty()) { for (int i = 0; i < axes_tensor_vct.size(); i++) { - final_axes.push_back(axes_tensor_vct[i].data()[0]); + final_axes.push_back(axes_tensor_vct[i]->data()[0]); } } else { LOG(FATAL) << "Input axis error"; @@ -114,16 +111,12 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { if (opdesc.HasInput("AxesTensorList") && opdesc.Input("AxesTensorList").size() > 0) { auto args = opdesc.Input("AxesTensorList"); - /* for (auto arg : args) { auto *var = scope->FindVar(arg); if (var != nullptr) { param_.axes_tensor_vct.push_back(var->GetMutable()); } } - */ - auto *var = scope->FindVar(args.front()); - param_.axes_tensor_vct = var->GetMutable>(); } CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null."; CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null."; diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc index 9bbf39b70d5aab67454233efb909f932e0b5bec1..22e475672a87dafee29d68a3824e4f8ac0c15615 100644 --- a/lite/tests/kernels/unsqueeze_compute_test.cc +++ b/lite/tests/kernels/unsqueeze_compute_test.cc @@ -125,8 +125,7 @@ class UnsqueezeComputeTester : public arena::TestCase { for (size_t i = 0; i < axes_.size(); i++) { name = name + std::to_string(i); axes_tensor_list_.push_back(name); - std::vector in_data = {axes_[i]}; - SetCommonTensor(name, DDim({1}), in_data.data()); + SetCommonTensor(name, DDim({1}), &axes_[i]); } } } @@ -230,7 +229,7 @@ void test_unsqueeze(Place place) { for (int C : {3}) { for (int H : {1}) { for (int W : {5}) { - for (int input_axes_flag : {1, 2}) { + for (int input_axes_flag : {1, 2, 3}) { LOG(INFO) << N << " " << C << " " << H << " " << W << " " << input_axes_flag; std::unique_ptr tester(