未验证 提交 c3f3643b 编写于 作者: W wenbin 提交者: GitHub

EmbEltwiseLayernorm fix (#40015)

* emb fix

* fix trt6 compile

* fix half

* absolute error fix
上级 5d9e11a4
......@@ -54,6 +54,8 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
case FluidDT::VarType_Type_FP16:
return TRT_DT::kHALF;
default:
return TRT_DT::kINT32;
}
......
......@@ -79,6 +79,28 @@ static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
model_input_shape_str, runtime_input_shape_str));
}
static paddle::experimental::DataType TRT2FluidDataType(
nvinfer1::DataType type) {
switch (type) {
case nvinfer1::DataType::kFLOAT:
return paddle::experimental::DataType::FLOAT32;
case nvinfer1::DataType::kINT32:
return paddle::experimental::DataType::INT32;
case nvinfer1::DataType::kHALF:
return paddle::experimental::DataType::FLOAT16;
case nvinfer1::DataType::kINT8:
return paddle::experimental::DataType::INT8;
#if IS_TRT_VERSION_GE(7000)
case nvinfer1::DataType::kBOOL:
return paddle::experimental::DataType::BOOL;
#endif
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in Fluid op converter"));
return paddle::experimental::DataType::FLOAT32;
}
}
static void RuntimeDynamicShapeCheck(
const std::string &x, const std::vector<int32_t> &runtime_input_shape,
const std::vector<int32_t> &min_input_shape,
......@@ -520,9 +542,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
} else if (type == framework::proto::VarType::INT32) {
buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
} else if (type == framework::proto::VarType::FP16) {
buffers[bind_index] = static_cast<void *>(t.data<float16>());
} else {
PADDLE_THROW(platform::errors::Fatal(
"The TRT Engine OP only support float/int32_t/int64_t input."));
PADDLE_THROW(
platform::errors::Fatal("The TRT Engine OP only support "
"float/int32_t/int64_t/float16 input."));
}
}
......@@ -570,9 +595,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
"than the number of bindings, but got binding "
"index = %d, number of bindings = %d.",
bind_index, num_bindings));
buffers[bind_index] =
static_cast<void *>(fluid_t->mutable_data<float>(dev_place));
auto trt_type = engine->engine()->getBindingDataType(bind_index);
// get adr and set type
buffers[bind_index] = static_cast<void *>(
fluid_t->mutable_data(dev_place, TRT2FluidDataType(trt_type)));
output_index += 1;
}
......
......@@ -244,28 +244,16 @@ class TrtConvertEmbEltwiseLayernormTest1(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 5), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 5), 1e-5
yield self.create_inference_config(), (0, 5), 2e-2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 4), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Half and len(
self.dynamic_shape.min_input_shape) != 0:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt when dynamic fp16 mode.")
yield self.create_inference_config(), (1, 4), 2e-2
def test(self):
self.add_skip_trt_case()
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册