未验证 提交 f4e06650 编写于 作者: J juncaipeng 提交者: GitHub

fix error for AxesTensorList in unsqueeze op, test=develop (#2411)

* fix error for AxesTensorList in unsqueeze op
上级 0ade1bc5
...@@ -856,7 +856,7 @@ struct UnsqueezeParam { ...@@ -856,7 +856,7 @@ struct UnsqueezeParam {
lite::Tensor* XShape{}; lite::Tensor* XShape{};
std::vector<int> axes{}; std::vector<int> axes{};
const lite::Tensor* axes_tensor{}; const lite::Tensor* axes_tensor{};
std::vector<lite::Tensor>* axes_tensor_vct{}; std::vector<const lite::Tensor*> axes_tensor_vct{};
}; };
/// ----------------------- expand operators ---------------------- /// ----------------------- expand operators ----------------------
......
...@@ -66,10 +66,7 @@ bool UnsqueezeOp::InferShape() const { ...@@ -66,10 +66,7 @@ bool UnsqueezeOp::InferShape() const {
std::vector<int> final_axes; std::vector<int> final_axes;
auto axes = param_.axes; auto axes = param_.axes;
auto *axes_tensor = param_.axes_tensor; auto *axes_tensor = param_.axes_tensor;
std::vector<lite::Tensor> axes_tensor_vct; auto axes_tensor_vct = param_.axes_tensor_vct;
if (param_.axes_tensor_vct) {
axes_tensor_vct = *(param_.axes_tensor_vct);
}
if (!axes.empty()) { if (!axes.empty()) {
final_axes = axes; final_axes = axes;
...@@ -79,7 +76,7 @@ bool UnsqueezeOp::InferShape() const { ...@@ -79,7 +76,7 @@ bool UnsqueezeOp::InferShape() const {
axes_tensor_data + axes_tensor->numel()); axes_tensor_data + axes_tensor->numel());
} else if (!axes_tensor_vct.empty()) { } else if (!axes_tensor_vct.empty()) {
for (int i = 0; i < axes_tensor_vct.size(); i++) { for (int i = 0; i < axes_tensor_vct.size(); i++) {
final_axes.push_back(axes_tensor_vct[i].data<int>()[0]); final_axes.push_back(axes_tensor_vct[i]->data<int>()[0]);
} }
} else { } else {
LOG(FATAL) << "Input axis error"; LOG(FATAL) << "Input axis error";
...@@ -114,16 +111,12 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -114,16 +111,12 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasInput("AxesTensorList") && if (opdesc.HasInput("AxesTensorList") &&
opdesc.Input("AxesTensorList").size() > 0) { opdesc.Input("AxesTensorList").size() > 0) {
auto args = opdesc.Input("AxesTensorList"); auto args = opdesc.Input("AxesTensorList");
/*
for (auto arg : args) { for (auto arg : args) {
auto *var = scope->FindVar(arg); auto *var = scope->FindVar(arg);
if (var != nullptr) { if (var != nullptr) {
param_.axes_tensor_vct.push_back(var->GetMutable<lite::Tensor>()); param_.axes_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
} }
} }
*/
auto *var = scope->FindVar(args.front());
param_.axes_tensor_vct = var->GetMutable<std::vector<lite::Tensor>>();
} }
CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null."; CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null."; CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
......
...@@ -125,8 +125,7 @@ class UnsqueezeComputeTester : public arena::TestCase { ...@@ -125,8 +125,7 @@ class UnsqueezeComputeTester : public arena::TestCase {
for (size_t i = 0; i < axes_.size(); i++) { for (size_t i = 0; i < axes_.size(); i++) {
name = name + std::to_string(i); name = name + std::to_string(i);
axes_tensor_list_.push_back(name); axes_tensor_list_.push_back(name);
std::vector<int> in_data = {axes_[i]}; SetCommonTensor(name, DDim({1}), &axes_[i]);
SetCommonTensor(name, DDim({1}), in_data.data());
} }
} }
} }
...@@ -230,7 +229,7 @@ void test_unsqueeze(Place place) { ...@@ -230,7 +229,7 @@ void test_unsqueeze(Place place) {
for (int C : {3}) { for (int C : {3}) {
for (int H : {1}) { for (int H : {1}) {
for (int W : {5}) { for (int W : {5}) {
for (int input_axes_flag : {1, 2}) { for (int input_axes_flag : {1, 2, 3}) {
LOG(INFO) << N << " " << C << " " << H << " " << W << " " LOG(INFO) << N << " " << C << " " << H << " " << W << " "
<< input_axes_flag; << input_axes_flag;
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册