未验证 提交 338fb32b 编写于 作者: C chen 提交者: GitHub

[TRT] PrelnResidualBiasPluginDynamic Support 4D Inputs (#56304)

上级 9b5f6140
...@@ -358,10 +358,25 @@ int PrelnResidualBiasPluginDynamic::enqueue( ...@@ -358,10 +358,25 @@ int PrelnResidualBiasPluginDynamic::enqueue(
void *workspace, void *workspace,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims; auto input_dims = input_desc[0].dims;
int hidden = input_dims.d[2]; int hidden;
const size_t rows = static_cast<size_t>( int rows_temp;
input_dims.d[0] * input_dims.d[1]); // batch * seq_length int cols_temp;
const size_t cols = static_cast<size_t>(input_dims.d[2]); if (input_dims.nbDims == 3) {
hidden = input_dims.d[2];
rows_temp = input_dims.d[0] * input_dims.d[1];
cols_temp = input_dims.d[2];
} else if (input_dims.nbDims == 4) {
hidden = input_dims.d[3];
rows_temp = input_dims.d[0] * input_dims.d[1] * input_dims.d[2];
cols_temp = input_dims.d[3];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"fused_bias_dropout_residual_layer_norm op only support 3-D or 4-D "
"input, but get `%d`-D.",
input_dims.nbDims));
}
const size_t rows = static_cast<size_t>(rows_temp); // batch * seq_length
const size_t cols = static_cast<size_t>(cols_temp);
auto input_type = input_desc[0].type; auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) { if (input_type == nvinfer1::DataType::kFLOAT) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册