提交 d6fb53c5 编写于 作者: J juncaipeng 提交者: GitHub

fix error for AxesTensorList in unsqueeze op, test=develop (#2411)

* fix error for AxesTensorList in unsqueeze op
上级 9f236a99
......@@ -856,7 +856,7 @@ struct UnsqueezeParam {
lite::Tensor* XShape{};
std::vector<int> axes{};
const lite::Tensor* axes_tensor{};
std::vector<lite::Tensor>* axes_tensor_vct{};
std::vector<const lite::Tensor*> axes_tensor_vct{};
};
/// ----------------------- expand operators ----------------------
......
......@@ -66,10 +66,7 @@ bool UnsqueezeOp::InferShape() const {
std::vector<int> final_axes;
auto axes = param_.axes;
auto *axes_tensor = param_.axes_tensor;
std::vector<lite::Tensor> axes_tensor_vct;
if (param_.axes_tensor_vct) {
axes_tensor_vct = *(param_.axes_tensor_vct);
}
auto axes_tensor_vct = param_.axes_tensor_vct;
if (!axes.empty()) {
final_axes = axes;
......@@ -79,7 +76,7 @@ bool UnsqueezeOp::InferShape() const {
axes_tensor_data + axes_tensor->numel());
} else if (!axes_tensor_vct.empty()) {
for (int i = 0; i < axes_tensor_vct.size(); i++) {
final_axes.push_back(axes_tensor_vct[i].data<int>()[0]);
final_axes.push_back(axes_tensor_vct[i]->data<int>()[0]);
}
} else {
LOG(FATAL) << "Input axis error";
......@@ -114,16 +111,12 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasInput("AxesTensorList") &&
opdesc.Input("AxesTensorList").size() > 0) {
auto args = opdesc.Input("AxesTensorList");
/*
for (auto arg : args) {
auto *var = scope->FindVar(arg);
if (var != nullptr) {
param_.axes_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
}
}
*/
auto *var = scope->FindVar(args.front());
param_.axes_tensor_vct = var->GetMutable<std::vector<lite::Tensor>>();
}
CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
......
......@@ -125,8 +125,7 @@ class UnsqueezeComputeTester : public arena::TestCase {
for (size_t i = 0; i < axes_.size(); i++) {
name = name + std::to_string(i);
axes_tensor_list_.push_back(name);
std::vector<int> in_data = {axes_[i]};
SetCommonTensor(name, DDim({1}), in_data.data());
SetCommonTensor(name, DDim({1}), &axes_[i]);
}
}
}
......@@ -230,7 +229,7 @@ void test_unsqueeze(Place place) {
for (int C : {3}) {
for (int H : {1}) {
for (int W : {5}) {
for (int input_axes_flag : {1, 2}) {
for (int input_axes_flag : {1, 2, 3}) {
LOG(INFO) << N << " " << C << " " << H << " " << W << " "
<< input_axes_flag;
std::unique_ptr<arena::TestCase> tester(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册