未验证 提交 dc3c0de1 编写于 作者: Z Zhang Jun 提交者: GitHub

fix reshape error: (Repeated layer name: reshape (layers must have distinct names)) (#54072)

上级 a299797d
...@@ -115,7 +115,10 @@ class CumsumOpConverter : public OpConverter { ...@@ -115,7 +115,10 @@ class CumsumOpConverter : public OpConverter {
[axis](int x) { return x == axis; }); [axis](int x) { return x == axis; });
subscripts.resize(p - subscripts.begin()); subscripts.resize(p - subscripts.begin());
auto newDims = Gather(Shape(inputSliced_output), subscripts); auto newDims = Gather(Shape(inputSliced_output), subscripts);
inputSliced_output = Reshape(inputSliced_output, newDims); inputSliced_output =
Reshape(inputSliced_output,
newDims,
("cumsum: reshape: (Output(" + output_name + ")").c_str());
// creat ZeroTensor // creat ZeroTensor
std::vector<float> zero_vec{0.f}; std::vector<float> zero_vec{0.f};
...@@ -127,7 +130,11 @@ class CumsumOpConverter : public OpConverter { ...@@ -127,7 +130,11 @@ class CumsumOpConverter : public OpConverter {
engine_, engine_,
ElementWise, ElementWise,
*inputSliced_output, *inputSliced_output,
*BroadcastTensors(cast->getOutput(0), inputSliced_output), *BroadcastTensors(cast->getOutput(0),
inputSliced_output,
("cumsum: reshape_for_broadcast: (Output(" +
output_name + ")")
.c_str()),
nvinfer1::ElementWiseOperation::kPROD) nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0); ->getOutput(0);
......
...@@ -403,15 +403,18 @@ class OpConverter { ...@@ -403,15 +403,18 @@ class OpConverter {
nvinfer1::ITensor* Reshape(nvinfer1::ITensor* input, nvinfer1::ITensor* Reshape(nvinfer1::ITensor* input,
nvinfer1::ITensor* newShape, nvinfer1::ITensor* newShape,
const std::string& name = "reshape") { const std::string& name = "") {
auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
shuffle->setInput(1, *newShape); shuffle->setInput(1, *newShape);
shuffle->setName(name.c_str()); if (name != "") {
shuffle->setName(name.c_str());
}
return shuffle->getOutput(0); return shuffle->getOutput(0);
} }
nvinfer1::ITensor* BroadcastTensor(nvinfer1::ITensor* input, nvinfer1::ITensor* BroadcastTensor(nvinfer1::ITensor* input,
const int nbDims) { const int nbDims,
const std::string& name = "") {
auto oldShape = Shape(input); auto oldShape = Shape(input);
auto oldShapeDims = oldShape->getDimensions(); auto oldShapeDims = oldShape->getDimensions();
const int rank = oldShapeDims.nbDims; const int rank = oldShapeDims.nbDims;
...@@ -427,22 +430,23 @@ class OpConverter { ...@@ -427,22 +430,23 @@ class OpConverter {
itensors.push_back(one_rank_tensor); itensors.push_back(one_rank_tensor);
itensors.push_back(oldShape); itensors.push_back(oldShape);
concat_shape_tensor = Concat(itensors); concat_shape_tensor = Concat(itensors);
input = Reshape(input, concat_shape_tensor); input = Reshape(input, concat_shape_tensor, name);
} }
return input; return input;
} }
nvinfer1::ITensor* BroadcastTensors(nvinfer1::ITensor* a, nvinfer1::ITensor* BroadcastTensors(nvinfer1::ITensor* a,
nvinfer1::ITensor* b) { nvinfer1::ITensor* b,
const std::string& name = "") {
const int aDims = a->getDimensions().nbDims; const int aDims = a->getDimensions().nbDims;
const int bDims = b->getDimensions().nbDims; const int bDims = b->getDimensions().nbDims;
if (aDims == bDims) { if (aDims == bDims) {
VLOG(3) << "Broadcast two equal rank tensors"; VLOG(3) << "Broadcast two equal rank tensors";
} }
if (aDims > bDims) { if (aDims > bDims) {
return BroadcastTensor(b, aDims); return BroadcastTensor(b, aDims, name);
} }
return BroadcastTensor(a, bDims); return BroadcastTensor(a, bDims, name);
} }
// Concat not make rank changed // Concat not make rank changed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册