未验证 提交 35148d17 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[BUG]: Head number can only be > 1 on multihead op (#23974)

* support the head number == 1
test=develop

* fix slice op error.
test=develop
上级 54a47cd2
......@@ -194,9 +194,8 @@ int GeluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
if (input_type == nvinfer1::DataType::kFLOAT) {
const float* input = static_cast<const float*>(inputs[0]);
float* output = static_cast<float*>(outputs[0]);
no_exact_gelu_kernel<float,
block_size><<<grid_size, block_size, 0, stream>>>(
kAT, kBT, kCT, num, input, output);
gelu_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
kA, num, input, output);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half* input = static_cast<const half*>(inputs[0]);
......
......@@ -68,10 +68,7 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
auto in_dims = inputs[0];
nvinfer1::DimsExprs ret;
for (int i = 0; i < ret.nbDims; i++) {
ret.d[i] = in_dims.d[i];
}
nvinfer1::DimsExprs ret = in_dims;
// starts and ends should be greater than 0
for (size_t i = 0; i < axes_.size(); i++) {
int start = starts_[i];
......
......@@ -69,13 +69,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
"but it's %d-D tensor now.",
dim_bias_qk.size()));
int head_number = context->Attrs().Get<int>("head_number");
PADDLE_ENFORCE_GT(
head_number, 1,
platform::errors::InvalidArgument(
"Multihead input head number should be at least 1, but it %d now.",
head_number));
// modify this
auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册