未验证 提交 fd679d31 编写于 作者: Z zhoutianzi666 提交者: GitHub

fix roll bug (#50391)

上级 d9a134c3
......@@ -28,7 +28,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to tensorRT.
* Roll converter from fluid to tensorRT.
*/
class RollOpConverter : public OpConverter {
public:
......@@ -53,7 +53,8 @@ class RollOpConverter : public OpConverter {
}
int axis_size = axis.size();
for (int i = 0; i < axis_size; i++) {
start.d[axis[i]] = (-shifts[i]) % input_dims.d[axis[i]];
start.d[axis[i]] =
(input_dims.d[axis[i]] - shifts[i]) % input_dims.d[axis[i]];
}
nvinfer1::Dims stride;
......@@ -70,11 +71,9 @@ class RollOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0];
auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
layer->setInput(2, *shape_layer->getOutput(0));
layer->setInput(2, *Shape(input));
#if IS_TRT_VERSION_GE(7000)
layer->setMode(nvinfer1::SliceMode::kWRAP);
#endif
......
......@@ -32,7 +32,7 @@ class TrtConvertRollTest(TrtLayerAutoScanTest):
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 56, 56, 192]).astype(np.float32)
return np.random.random([1, 56, 56, 192]).astype(np.float32)
for axis in [[1, 2]]:
for shifts in [[-1, -1], [-3, -3]]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册