未验证 提交 97b09e81 作者: Leo Chen 提交者: GitHub

Add elementwise_add into Paddle-TRT NHWC support (#56795)

* Add elementwise_add support into NHWC IR
上级 657b6401
...@@ -161,7 +161,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { ...@@ -161,7 +161,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
"affine_channel", "softmax", "temporal_shift"}; "affine_channel", "softmax", "temporal_shift"};
// OPs unrelated to layout are consistent according to the layout of input // OPs unrelated to layout are consistent according to the layout of input
// var! // var!
std::unordered_set<std::string> any_layout_ops{"relu"}; std::unordered_set<std::string> any_layout_ops{"relu", "elementwise_add"};
// //
// //
// TODO(liuyuanle): Add other op if needed! // TODO(liuyuanle): Add other op if needed!
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import shutil import shutil
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -53,6 +55,15 @@ class SimpleNet(nn.Layer): ...@@ -53,6 +55,15 @@ class SimpleNet(nn.Layer):
data_format='NHWC', data_format='NHWC',
) )
self.relu3 = nn.ReLU() self.relu3 = nn.ReLU()
self.conv4 = nn.Conv2D(
in_channels=2,
out_channels=1,
kernel_size=3,
stride=2,
padding=0,
data_format='NHWC',
)
self.relu4 = nn.ReLU()
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
self.fc = nn.Linear(729, 10) self.fc = nn.Linear(729, 10)
self.softmax = nn.Softmax() self.softmax = nn.Softmax()
...@@ -62,8 +73,12 @@ class SimpleNet(nn.Layer): ...@@ -62,8 +73,12 @@ class SimpleNet(nn.Layer):
x = self.relu1(x) x = self.relu1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.relu2(x) x = self.relu2(x)
res = x
x = self.conv3(x) x = self.conv3(x)
x = self.relu3(x) x = self.relu3(x)
res = self.conv4(res)
res = self.relu4(res)
x = x + res
x = self.flatten(x) x = self.flatten(x)
x = self.fc(x) x = self.fc(x)
x = self.softmax(x) x = self.softmax(x)
...@@ -73,7 +88,11 @@ class SimpleNet(nn.Layer): ...@@ -73,7 +88,11 @@ class SimpleNet(nn.Layer):
class TRTNHWCConvertTest(unittest.TestCase): class TRTNHWCConvertTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = paddle.CUDAPlace(0) self.place = paddle.CUDAPlace(0)
self.path = './inference_pass/nhwc_convert/infer_model' self.temp_dir = tempfile.TemporaryDirectory()
self.path = os.path.join(
self.temp_dir.name, 'inference_pass', 'nhwc_converter', ''
)
self.model_prefix = self.path + 'infer_model'
def create_model(self): def create_model(self):
image = static.data( image = static.data(
...@@ -82,11 +101,13 @@ class TRTNHWCConvertTest(unittest.TestCase): ...@@ -82,11 +101,13 @@ class TRTNHWCConvertTest(unittest.TestCase):
predict = SimpleNet()(image) predict = SimpleNet()(image)
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
paddle.static.save_inference_model(self.path, [image], [predict], exe) paddle.static.save_inference_model(
self.model_prefix, [image], [predict], exe
)
def create_predictor(self): def create_predictor(self):
config = paddle.inference.Config( config = paddle.inference.Config(
self.path + '.pdmodel', self.path + '.pdiparams' self.model_prefix + '.pdmodel', self.model_prefix + '.pdiparams'
) )
config.enable_memory_optim() config.enable_memory_optim()
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
...@@ -123,7 +144,7 @@ class TRTNHWCConvertTest(unittest.TestCase): ...@@ -123,7 +144,7 @@ class TRTNHWCConvertTest(unittest.TestCase):
result = self.infer(predictor, img=[img]) result = self.infer(predictor, img=[img])
def tearDown(self): def tearDown(self):
shutil.rmtree('./inference_pass/nhwc_convert/') shutil.rmtree(self.path)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册