未验证 提交 c3f3643b 编写于 作者: W wenbin 提交者: GitHub

EmbEltwiseLayernorm fix (#40015)

* emb fix

* fix trt6 compile

* fix half

* absolute error fix
上级 5d9e11a4
...@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) { ...@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
return TRT_DT::kFLOAT; return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32: case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32; return TRT_DT::kINT32;
case FluidDT::VarType_Type_FP16:
return TRT_DT::kHALF;
default: default:
return TRT_DT::kINT32; return TRT_DT::kINT32;
} }
......
...@@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape, ...@@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
model_input_shape_str, runtime_input_shape_str)); model_input_shape_str, runtime_input_shape_str));
} }
// Maps a TensorRT runtime data type (nvinfer1::DataType) to the matching
// Paddle (fluid) data type used when allocating output tensors.
//
// Params:  type - the TensorRT binding data type reported by the engine.
// Returns: the corresponding paddle::experimental::DataType.
// Throws:  platform::errors::InvalidArgument for any TRT type with no
//          fluid counterpart (e.g. kBOOL on TRT < 7.0 builds).
static paddle::experimental::DataType TRT2FluidDataType(
    nvinfer1::DataType type) {
  switch (type) {
    case nvinfer1::DataType::kFLOAT:
      return paddle::experimental::DataType::FLOAT32;
    case nvinfer1::DataType::kINT32:
      return paddle::experimental::DataType::INT32;
    case nvinfer1::DataType::kHALF:
      return paddle::experimental::DataType::FLOAT16;
    case nvinfer1::DataType::kINT8:
      return paddle::experimental::DataType::INT8;
#if IS_TRT_VERSION_GE(7000)
    // kBOOL was introduced in TensorRT 7.0; guard so TRT6 still compiles.
    case nvinfer1::DataType::kBOOL:
      return paddle::experimental::DataType::BOOL;
#endif
    default:
      // Fix: the original message said "unknown fluid datatype", but the
      // unmapped value here is a TensorRT datatype being converted to fluid.
      PADDLE_THROW(platform::errors::InvalidArgument(
          "unknown TRT datatype in TRT->Fluid conversion"));
      // Unreachable after PADDLE_THROW; kept to satisfy compilers that do
      // not treat the macro as noreturn.
      return paddle::experimental::DataType::FLOAT32;
  }
}
static void RuntimeDynamicShapeCheck( static void RuntimeDynamicShapeCheck(
const std::string &x, const std::vector<int32_t> &runtime_input_shape, const std::string &x, const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape, const std::vector<int32_t> &min_input_shape,
...@@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers[bind_index] = static_cast<void *>(t.data<int64_t>()); buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
} else if (type == framework::proto::VarType::INT32) { } else if (type == framework::proto::VarType::INT32) {
buffers[bind_index] = static_cast<void *>(t.data<int32_t>()); buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
} else if (type == framework::proto::VarType::FP16) {
buffers[bind_index] = static_cast<void *>(t.data<float16>());
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(
"The TRT Engine OP only support float/int32_t/int64_t input.")); platform::errors::Fatal("The TRT Engine OP only support "
"float/int32_t/int64_t/float16 input."));
} }
} }
...@@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
"than the number of bindings, but got binding " "than the number of bindings, but got binding "
"index = %d, number of bindings = %d.", "index = %d, number of bindings = %d.",
bind_index, num_bindings)); bind_index, num_bindings));
buffers[bind_index] = auto trt_type = engine->engine()->getBindingDataType(bind_index);
static_cast<void *>(fluid_t->mutable_data<float>(dev_place)); // get adr and set type
buffers[bind_index] = static_cast<void *>(
fluid_t->mutable_data(dev_place, TRT2FluidDataType(trt_type)));
output_index += 1; output_index += 1;
} }
......
...@@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest): ...@@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 5), 1e-5 yield self.create_inference_config(), (0, 5), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 5), 1e-5 yield self.create_inference_config(), (0, 5), 2e-2
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 4), 1e-5 yield self.create_inference_config(), (1, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 4), 1e-5 yield self.create_inference_config(), (1, 4), 2e-2
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Half and len(
self.dynamic_shape.min_input_shape) != 0:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt when dynamic fp16 mode.")
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册