未验证 提交 fd679d31 编写于 作者: Z zhoutianzi666 提交者: GitHub

fix roll bug (#50391)

上级 d9a134c3
...@@ -28,7 +28,7 @@ namespace paddle { ...@@ -28,7 +28,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
/* /*
* Stack converter from fluid to tensorRT. * Roll converter from fluid to tensorRT.
*/ */
class RollOpConverter : public OpConverter { class RollOpConverter : public OpConverter {
public: public:
...@@ -53,7 +53,8 @@ class RollOpConverter : public OpConverter { ...@@ -53,7 +53,8 @@ class RollOpConverter : public OpConverter {
} }
int axis_size = axis.size(); int axis_size = axis.size();
for (int i = 0; i < axis_size; i++) { for (int i = 0; i < axis_size; i++) {
start.d[axis[i]] = (-shifts[i]) % input_dims.d[axis[i]]; start.d[axis[i]] =
(input_dims.d[axis[i]] - shifts[i]) % input_dims.d[axis[i]];
} }
nvinfer1::Dims stride; nvinfer1::Dims stride;
...@@ -70,11 +71,9 @@ class RollOpConverter : public OpConverter { ...@@ -70,11 +71,9 @@ class RollOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* layer = auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride); TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
layer->setInput(2, *shape_layer->getOutput(0)); layer->setInput(2, *Shape(input));
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
layer->setMode(nvinfer1::SliceMode::kWRAP); layer->setMode(nvinfer1::SliceMode::kWRAP);
#endif #endif
......
...@@ -32,7 +32,7 @@ class TrtConvertRollTest(TrtLayerAutoScanTest): ...@@ -32,7 +32,7 @@ class TrtConvertRollTest(TrtLayerAutoScanTest):
def sample_program_configs(self): def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]): def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 56, 56, 192]).astype(np.float32) return np.random.random([1, 56, 56, 192]).astype(np.float32)
for axis in [[1, 2]]: for axis in [[1, 2]]:
for shifts in [[-1, -1], [-3, -3]]: for shifts in [[-1, -1], [-3, -3]]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册