diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc index 0b12559e31deb85d50dca174acfdc50a489e82a9..0403330f77cd187bbf0f6378164555267cb42e42 100644 --- a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc @@ -161,7 +161,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { "affine_channel", "softmax", "temporal_shift"}; // OPs unrelated to layout are consistent according to the layout of input // var! - std::unordered_set any_layout_ops{"relu"}; + std::unordered_set any_layout_ops{"relu", "elementwise_add"}; // // // TODO(liuyuanle): Add other op if needed! diff --git a/test/ir/inference/test_trt_support_nhwc_pass.py b/test/ir/inference/test_trt_support_nhwc_pass.py index 7c0a6eb4b4a8904c694d81e2a399e1b9d4710233..0648202aba30c4172c0ae7c7f76fe927f312bc25 100644 --- a/test/ir/inference/test_trt_support_nhwc_pass.py +++ b/test/ir/inference/test_trt_support_nhwc_pass.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import shutil +import tempfile import unittest import numpy as np @@ -53,6 +55,15 @@ class SimpleNet(nn.Layer): data_format='NHWC', ) self.relu3 = nn.ReLU() + self.conv4 = nn.Conv2D( + in_channels=2, + out_channels=1, + kernel_size=3, + stride=2, + padding=0, + data_format='NHWC', + ) + self.relu4 = nn.ReLU() self.flatten = nn.Flatten() self.fc = nn.Linear(729, 10) self.softmax = nn.Softmax() @@ -62,8 +73,12 @@ class SimpleNet(nn.Layer): x = self.relu1(x) x = self.conv2(x) x = self.relu2(x) + res = x x = self.conv3(x) x = self.relu3(x) + res = self.conv4(res) + res = self.relu4(res) + x = x + res x = self.flatten(x) x = self.fc(x) x = self.softmax(x) @@ -73,7 +88,11 @@ class SimpleNet(nn.Layer): class TRTNHWCConvertTest(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) - self.path = './inference_pass/nhwc_convert/infer_model' + self.temp_dir = tempfile.TemporaryDirectory() + self.path = os.path.join( + self.temp_dir.name, 'inference_pass', 'nhwc_converter', '' + ) + self.model_prefix = self.path + 'infer_model' def create_model(self): image = static.data( @@ -82,11 +101,13 @@ class TRTNHWCConvertTest(unittest.TestCase): predict = SimpleNet()(image) exe = paddle.static.Executor(self.place) exe.run(paddle.static.default_startup_program()) - paddle.static.save_inference_model(self.path, [image], [predict], exe) + paddle.static.save_inference_model( + self.model_prefix, [image], [predict], exe + ) def create_predictor(self): config = paddle.inference.Config( - self.path + '.pdmodel', self.path + '.pdiparams' + self.model_prefix + '.pdmodel', self.model_prefix + '.pdiparams' ) config.enable_memory_optim() config.enable_use_gpu(100, 0) @@ -123,7 +144,7 @@ class TRTNHWCConvertTest(unittest.TestCase): result = self.infer(predictor, img=[img]) def tearDown(self): - shutil.rmtree('./inference_pass/nhwc_convert/') + shutil.rmtree(self.path) if __name__ == '__main__':