未验证 提交 338fb32b 编写于 作者: C chen 提交者: GitHub

[TRT] PrelnResidualBiasPluginDynamic Support 4D Inputs (#56304)

上级 9b5f6140
......@@ -358,10 +358,25 @@ int PrelnResidualBiasPluginDynamic::enqueue(
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
int hidden = input_dims.d[2];
const size_t rows = static_cast<size_t>(
input_dims.d[0] * input_dims.d[1]); // batch * seq_length
const size_t cols = static_cast<size_t>(input_dims.d[2]);
int hidden;
int rows_temp;
int cols_temp;
if (input_dims.nbDims == 3) {
hidden = input_dims.d[2];
rows_temp = input_dims.d[0] * input_dims.d[1];
cols_temp = input_dims.d[2];
} else if (input_dims.nbDims == 4) {
hidden = input_dims.d[3];
rows_temp = input_dims.d[0] * input_dims.d[1] * input_dims.d[2];
cols_temp = input_dims.d[3];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"fused_bias_dropout_residual_layer_norm op only support 3-D or 4-D "
"input, but get `%d`-D.",
input_dims.nbDims));
}
const size_t rows = static_cast<size_t>(rows_temp); // batch * seq_length
const size_t cols = static_cast<size_t>(cols_temp);
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册