提交 c5e47b11 编写于 作者: S sunsuodong

fix_concat_slice

上级 7371cedd
...@@ -107,14 +107,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor ...@@ -107,14 +107,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} }
auto input0_shape_without_axis = input0_shape; auto input0_shape_without_axis = input0_shape;
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis); input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
auto input0_data_type = inputs_.at(0)->data_type();
int output_axis_dim = input0_shape.at(axis); int output_axis_dim = input0_shape.at(axis);
for (size_t i = 1; i < inputs_.size(); ++i) { for (size_t i = 1; i < inputs_.size(); ++i) {
if (inputs_.at(i)->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All inputs should have the same data type!";
return RET_PARAM_INVALID;
}
auto shape_tmp = inputs_.at(i)->shape(); auto shape_tmp = inputs_.at(i)->shape();
if (shape_tmp.size() != input0_shape.size()) { if (shape_tmp.size() != input0_shape.size()) {
MS_LOG(ERROR) << "All inputs should have the same dim num!"; MS_LOG(ERROR) << "All inputs should have the same dim num!";
......
...@@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: ...@@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
return RET_ERROR; return RET_ERROR;
} }
std::vector<int32_t> axes;
if (attr->axes() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axes()->size()); i++) {
axes.push_back(attr->axes()->data()[i]);
}
}
std::vector<int32_t> begin; std::vector<int32_t> begin;
if (attr->begin() != nullptr) { if (attr->begin() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) { for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
...@@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: ...@@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
} }
} }
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size); auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &axes, &begin, &size);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o);
fbb->Finish(prim_offset); fbb->Finish(prim_offset);
return RET_OK; return RET_OK;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册