From 35148d17f7c170fb4c3e448bea4f6557bb90e566 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Thu, 23 Apr 2020 00:51:15 +0800 Subject: [PATCH] [BUG]: Head number can only be > 1 on multihead op (#23974) * support the head number == 1 test=develop * fix slice op error. test=develop --- paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu | 5 ++--- paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu | 5 +---- paddle/fluid/operators/fused/multihead_matmul_op.cc | 7 ------- 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu index 30f1c37ab18..6a718d47b15 100644 --- a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu @@ -194,9 +194,8 @@ int GeluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, if (input_type == nvinfer1::DataType::kFLOAT) { const float* input = static_cast(inputs[0]); float* output = static_cast(outputs[0]); - no_exact_gelu_kernel<<>>( - kAT, kBT, kCT, num, input, output); + gelu_kernel<<>>( + kA, num, input, output); } else if (input_type == nvinfer1::DataType::kHALF) { #ifdef SUPPORTS_CUDA_FP16 const half* input = static_cast(inputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 9c51bc9b8d2..7b2b7b10f08 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -68,10 +68,7 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, nvinfer1::IExprBuilder &expr_builder) { auto in_dims = inputs[0]; - nvinfer1::DimsExprs ret; - for (int i = 0; i < ret.nbDims; i++) { - ret.d[i] = in_dims.d[i]; - } + nvinfer1::DimsExprs ret = in_dims; // start, ends should greater 0 for (size_t i = 0; i < axes_.size(); i++) { int start = starts_[i]; diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc index ad8db4c62ec..8f2c04d5afe 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cc +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc @@ -69,13 +69,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { "but it's %d-D tensor now.", dim_bias_qk.size())); - int head_number = context->Attrs().Get("head_number"); - PADDLE_ENFORCE_GT( - head_number, 1, - platform::errors::InvalidArgument( - "Multihead input head number should be at least 1, but it %d now.", - head_number)); - // modify this auto dim_input = context->GetInputDim("Input"); context->SetOutputDim("Out", dim_input); context->ShareLoD("Input", /*->*/ "Out"); -- GitLab