未验证 提交 97b09e81 编写于 作者: L Leo Chen 提交者: GitHub

Add elementwise_add into Paddle-TRT NHWC support (#56795)

* Add elementwise_add support into NHWC IR
上级 657b6401
......@@ -161,7 +161,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
"affine_channel", "softmax", "temporal_shift"};
// OPs unrelated to layout are consistent according to the layout of input
// var!
std::unordered_set<std::string> any_layout_ops{"relu"};
std::unordered_set<std::string> any_layout_ops{"relu", "elementwise_add"};
//
//
// TODO(liuyuanle): Add other op if needed!
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
import numpy as np
......@@ -53,6 +55,15 @@ class SimpleNet(nn.Layer):
data_format='NHWC',
)
self.relu3 = nn.ReLU()
self.conv4 = nn.Conv2D(
in_channels=2,
out_channels=1,
kernel_size=3,
stride=2,
padding=0,
data_format='NHWC',
)
self.relu4 = nn.ReLU()
self.flatten = nn.Flatten()
self.fc = nn.Linear(729, 10)
self.softmax = nn.Softmax()
......@@ -62,8 +73,12 @@ class SimpleNet(nn.Layer):
x = self.relu1(x)
x = self.conv2(x)
x = self.relu2(x)
res = x
x = self.conv3(x)
x = self.relu3(x)
res = self.conv4(res)
res = self.relu4(res)
x = x + res
x = self.flatten(x)
x = self.fc(x)
x = self.softmax(x)
......@@ -73,7 +88,11 @@ class SimpleNet(nn.Layer):
class TRTNHWCConvertTest(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.path = './inference_pass/nhwc_convert/infer_model'
self.temp_dir = tempfile.TemporaryDirectory()
self.path = os.path.join(
self.temp_dir.name, 'inference_pass', 'nhwc_converter', ''
)
self.model_prefix = self.path + 'infer_model'
def create_model(self):
image = static.data(
......@@ -82,11 +101,13 @@ class TRTNHWCConvertTest(unittest.TestCase):
predict = SimpleNet()(image)
exe = paddle.static.Executor(self.place)
exe.run(paddle.static.default_startup_program())
paddle.static.save_inference_model(self.path, [image], [predict], exe)
paddle.static.save_inference_model(
self.model_prefix, [image], [predict], exe
)
def create_predictor(self):
config = paddle.inference.Config(
self.path + '.pdmodel', self.path + '.pdiparams'
self.model_prefix + '.pdmodel', self.model_prefix + '.pdiparams'
)
config.enable_memory_optim()
config.enable_use_gpu(100, 0)
......@@ -123,7 +144,7 @@ class TRTNHWCConvertTest(unittest.TestCase):
result = self.infer(predictor, img=[img])
def tearDown(self):
shutil.rmtree('./inference_pass/nhwc_convert/')
shutil.rmtree(self.path)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册