Unverified commit ac744db1, authored by Wangzheee, committed by GitHub

fix Paddle-Trt concat, slice (#39277)

Parent 397781f1
@@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter {
       itensors.push_back(engine_->GetITensor(input_name));
     }
     int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
+    if (axis == -1) {
+      axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions())
+                 .nbDims -
+             1;
+    } else {
+      if (!engine_->with_dynamic_shape()) {
+        axis = axis - 1;  // Remove batch dim
+      }
+    }
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
                                        itensors.size());
-    if (!engine_->with_dynamic_shape()) {
-      axis = axis - 1;  // Remove batch dim
-    }
     layer->setAxis(axis);
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode);
......
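For reference, a minimal sketch of the axis mapping this hunk implements; MapConcatAxis is a hypothetical helper (not Paddle code), and trt_input_rank stands for the nbDims of the input ITensor as TensorRT sees it:

// Hedged restatement of the converter's new axis handling: axis == -1 is
// resolved to the last TensorRT dimension directly, while an explicit
// non-negative axis is shifted down by one only in static-shape mode, where
// TensorRT dimensions exclude the implicit batch dimension.
int MapConcatAxis(int axis, int trt_input_rank, bool dynamic_shape) {
  if (axis == -1) return trt_input_rank - 1;
  return dynamic_shape ? axis : axis - 1;
}

Before this fix, axis == -1 fell into the batch-dim branch and became -2 in static-shape mode; with the sketch above, MapConcatAxis(-1, 3, false) and MapConcatAxis(3, 3, false) both yield 2.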
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -72,7 +69,8 @@ class SliceOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
 #if IS_TRT_VERSION_GE(6000)
-      if (engine_->use_oss() && engine_->with_ernie()) {
+      if (engine_->use_oss() && engine_->with_ernie() &&
+          input_dims.nbDims == 4) {
         std::vector<nvinfer1::ITensor*> plugin_inputs;
         // plugin_inputs.emplace_back(trans_layer->getOutput(0));
         plugin_inputs.emplace_back(input);
......
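The only functional change in this hunk is the extra rank guard. A hedged restatement, where UseErnieSpecialSlice is an illustrative predicate rather than the converter's real API:

// The ERNIE OSS special-slice plugin path now also requires a 4-D input,
// e.g. (sum(S), hidden, 1, 1); inputs of any other rank fall through to the
// generic slice handling below this branch.
bool UseErnieSpecialSlice(bool use_oss, bool with_ernie, int nb_dims) {
  return use_oss && with_ernie && nb_dims == 4;
}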
@@ -421,10 +421,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
-      if (with_dynamic_shape) {
-        if (axis < 0) return false;
-      } else {
-        if (axis <= 0) return false;
+      if (!with_dynamic_shape) {
+        if (axis == 0) return false;
       }
       auto concat_inputs = desc.Inputs();
       if (concat_inputs.find("AxisTensor") != concat_inputs.end()) {
......
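A standalone restatement of the relaxed teller rule; ConcatAxisSupported is an illustrative helper, not OpTeller's real interface:

// Dynamic shape now accepts any axis, negative ones included; static shape
// rejects only axis == 0, which would concatenate along the batch dimension.
bool ConcatAxisSupported(int axis, bool with_dynamic_shape) {
  return with_dynamic_shape || axis != 0;
}

This pairs with the converter change above: an axis of -1 is normalized to the last dimension before the concatenation layer is built, so it no longer has to be rejected here.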
@@ -113,32 +113,38 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
 template <typename T>
 __global__ void SpecialSliceKernel(const T* slice_input,
                                    const int32_t* cu_seqlens, T* output) {
-  const int hidden = blockDim.x;
+  const int hidden = blockDim.x * gridDim.y;
   const int batch = blockIdx.x;
+  const int local_idx = blockIdx.y * blockDim.x + threadIdx.x;
 
-  output[batch * hidden + threadIdx.x] =
-      slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
+  output[batch * hidden + local_idx] =
+      slice_input[cu_seqlens[batch] * hidden + local_idx];
 }
 
 int SpecialSlicePluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc* input_desc,
     const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
     void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
-  auto input_dims = input_desc[0].dims;  // (sum(S), 768, 1, 1)
-  auto out_dims = output_desc[0].dims;   // (batch, 768, 1, 1)
-  assert(input_desc[0].type == nvinfer1::DataType::kHALF);
+  auto input_dims = input_desc[0].dims;  // (sum(S), hidden, 1, 1)
+  auto out_dims = output_desc[0].dims;   // (batch, hidden, 1, 1)
+  PADDLE_ENFORCE_EQ(
+      input_desc[0].type, nvinfer1::DataType::kHALF,
+      platform::errors::InvalidArgument("Type of input should be half."));
 
   const int32_t hidden = input_dims.d[1];
-  const int num_blocks = out_dims.d[0];  // batch size
-  const int num_threads = hidden;
+  PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument(
+                                         "hidden should be multiple of 128."));
+  constexpr int num_threads = 128;
+  const dim3 blocks(out_dims.d[0], hidden / num_threads);
 
   const half* slice_input = static_cast<const half*>(inputs[0]);
   const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
   half* output = static_cast<half*>(outputs[0]);
 
-  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
-      slice_input, cu_seqlens, output);
+  SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
+                                                         cu_seqlens, output);
   return cudaGetLastError() != cudaSuccess;
 }
......
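The patched kernel swaps one block per sequence with hidden threads (a layout that cannot launch once hidden exceeds CUDA's 1024-thread block limit) for a 2-D grid of 128-thread blocks. Below is a standalone toy sketch of the same indexing under stated assumptions: float instead of half, made-up sizes, and SliceFirstTokenKernel as an illustrative name, not Paddle code:

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Copies the first-token row of each sequence: grid = (batch, hidden / 128),
// block = 128 threads, so thread (blockIdx.y, threadIdx.x) handles hidden
// element local_idx of sequence blockIdx.x.
__global__ void SliceFirstTokenKernel(const float* in, const int* cu_seqlens,
                                      float* out) {
  const int hidden = blockDim.x * gridDim.y;
  const int batch = blockIdx.x;
  const int local_idx = blockIdx.y * blockDim.x + threadIdx.x;
  out[batch * hidden + local_idx] = in[cu_seqlens[batch] * hidden + local_idx];
}

int main() {
  const int batch = 2, hidden = 256, total_tokens = 5;
  std::vector<float> h_in(total_tokens * hidden);
  for (size_t i = 0; i < h_in.size(); ++i) h_in[i] = static_cast<float>(i);
  std::vector<int> h_seq = {0, 3};  // row offset of each sequence's first token
  float *d_in, *d_out;
  int* d_seq;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, batch * hidden * sizeof(float));
  cudaMalloc(&d_seq, batch * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_seq, h_seq.data(), batch * sizeof(int), cudaMemcpyHostToDevice);
  const int num_threads = 128;                     // as in the patched plugin
  const dim3 blocks(batch, hidden / num_threads);  // 2-D grid, here (2, 2)
  SliceFirstTokenKernel<<<blocks, num_threads>>>(d_in, d_seq, d_out);
  std::vector<float> h_out(batch * hidden);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  // Expect out[0][0] = in[0 * 256] = 0 and out[1][0] = in[3 * 256] = 768.
  printf("out[0][0]=%.0f out[1][0]=%.0f\n", h_out[0], h_out[hidden]);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_seq);
  return 0;
}

With hidden = 768 as in ERNIE, the grid becomes (batch, 6); the hidden % 128 == 0 check above guarantees the 2-D decomposition covers every element exactly once.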
@@ -71,7 +71,7 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
         def generate_weight1(attrs: List[Dict[str, Any]]):
             return np.zeros([1]).astype(np.int32)
 
-        for dims in [1, 2, 3, 4]:
+        for dims in [2, 3, 4]:
             for num_input in [0, 1]:
                 for batch in [1, 2, 4]:
                     for axis in [-1, 0, 1, 2, 3]:
@@ -277,12 +277,9 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
         def generate_trt_nodes_num(attrs, dynamic_shape):
             if dynamic_shape == True:
-                if attrs[0]['axis'] >= 0:
-                    return 1, 4
-                else:
-                    return 0, 5
+                return 1, 4
             else:
-                if attrs[0]['axis'] > 0:
+                if attrs[0]['axis'] != 0:
                     return 1, 4
                 else:
                     return 0, 5
......