未验证 提交 814f1d65 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix fp16 io (#54129)

* fix fp16 io

* disable precision test
上级 85907b58
...@@ -509,6 +509,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -509,6 +509,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
// when op_1 only support cpu kernel. if op_2's intput var is op_1's // when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run at low precision. // output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" && if (GetOpOriginalType(op_type) != "feed" &&
GetOpOriginalType(op_type) != "tensorrt_engine" &&
!KernelSupportPrecision( !KernelSupportPrecision(
GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) { GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
......
...@@ -66,10 +66,11 @@ class TestEnableLowPrecisionIO: ...@@ -66,10 +66,11 @@ class TestEnableLowPrecisionIO:
fp32_output = self.get_fp32_output() fp32_output = self.get_fp32_output()
fp16_output = self.get_fp16_output() fp16_output = self.get_fp16_output()
# np.testing.assert_allclose( # if os.name == 'posix':
# fp32_output.numpy().flatten(), # np.testing.assert_allclose(
# fp16_output.numpy().flatten(), # fp32_output.numpy().flatten(),
# ) # fp16_output.numpy().flatten(),
# )
class TestEnableLowPrecisionIOWithGPU( class TestEnableLowPrecisionIOWithGPU(
...@@ -105,6 +106,7 @@ class TestEnableLowPrecisionIOWithTRTAllGraph( ...@@ -105,6 +106,7 @@ class TestEnableLowPrecisionIOWithTRTAllGraph(
use_static=False, use_static=False,
use_calib_mode=False, use_calib_mode=False,
) )
config.enable_tuned_tensorrt_dynamic_shape()
config.enable_memory_optim() config.enable_memory_optim()
config.enable_low_precision_io(low_precision_io) config.enable_low_precision_io(low_precision_io)
config.disable_glog_info() config.disable_glog_info()
...@@ -129,6 +131,7 @@ class TestEnableLowPrecisionIOWithTRTSubGraph( ...@@ -129,6 +131,7 @@ class TestEnableLowPrecisionIOWithTRTSubGraph(
use_static=False, use_static=False,
use_calib_mode=False, use_calib_mode=False,
) )
config.enable_tuned_tensorrt_dynamic_shape()
config.enable_memory_optim() config.enable_memory_optim()
config.enable_low_precision_io(low_precision_io) config.enable_low_precision_io(low_precision_io)
config.exp_disable_tensorrt_ops(["flatten_contiguous_range"]) config.exp_disable_tensorrt_ops(["flatten_contiguous_range"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册