From f4e06650d03179e4b055dc4c3636255b14dbaf4c Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 13 Nov 2019 10:27:04 +0800 Subject: [PATCH] fix error for AxesTensorList in unsqueeze op, test=develop (#2411) * fix error for AxesTensorList in unsqueeze op --- lite/operators/op_params.h | 2 +- lite/operators/unsqueeze_op.cc | 11 ++--------- lite/tests/kernels/unsqueeze_compute_test.cc | 5 ++--- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 8609f17888..32b80518e5 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -856,7 +856,7 @@ struct UnsqueezeParam { lite::Tensor* XShape{}; std::vector axes{}; const lite::Tensor* axes_tensor{}; - std::vector* axes_tensor_vct{}; + std::vector axes_tensor_vct{}; }; /// ----------------------- expand operators ---------------------- diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 8db14d0660..39b275b7b5 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -66,10 +66,7 @@ bool UnsqueezeOp::InferShape() const { std::vector final_axes; auto axes = param_.axes; auto *axes_tensor = param_.axes_tensor; - std::vector axes_tensor_vct; - if (param_.axes_tensor_vct) { - axes_tensor_vct = *(param_.axes_tensor_vct); - } + auto axes_tensor_vct = param_.axes_tensor_vct; if (!axes.empty()) { final_axes = axes; @@ -79,7 +76,7 @@ bool UnsqueezeOp::InferShape() const { axes_tensor_data + axes_tensor->numel()); } else if (!axes_tensor_vct.empty()) { for (int i = 0; i < axes_tensor_vct.size(); i++) { - final_axes.push_back(axes_tensor_vct[i].data()[0]); + final_axes.push_back(axes_tensor_vct[i]->data()[0]); } } else { LOG(FATAL) << "Input axis error"; @@ -114,16 +111,12 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { if (opdesc.HasInput("AxesTensorList") && opdesc.Input("AxesTensorList").size() > 0) { auto args = opdesc.Input("AxesTensorList"); - /* for (auto arg : args) { auto *var = scope->FindVar(arg); if (var != nullptr) { param_.axes_tensor_vct.push_back(var->GetMutable()); } } - */ - auto *var = scope->FindVar(args.front()); - param_.axes_tensor_vct = var->GetMutable>(); } CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null."; CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null."; diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc index 9bbf39b70d..22e475672a 100644 --- a/lite/tests/kernels/unsqueeze_compute_test.cc +++ b/lite/tests/kernels/unsqueeze_compute_test.cc @@ -125,8 +125,7 @@ class UnsqueezeComputeTester : public arena::TestCase { for (size_t i = 0; i < axes_.size(); i++) { name = name + std::to_string(i); axes_tensor_list_.push_back(name); - std::vector in_data = {axes_[i]}; - SetCommonTensor(name, DDim({1}), in_data.data()); + SetCommonTensor(name, DDim({1}), &axes_[i]); } } } @@ -230,7 +229,7 @@ void test_unsqueeze(Place place) { for (int C : {3}) { for (int H : {1}) { for (int W : {5}) { - for (int input_axes_flag : {1, 2}) { + for (int input_axes_flag : {1, 2, 3}) { LOG(INFO) << N << " " << C << " " << H << " " << W << " " << input_axes_flag; std::unique_ptr tester( -- GitLab