未验证 提交 ac744db1 编写于 作者: W Wangzheee 提交者: GitHub

fix Paddle-Trt concat, slice (#39277)

上级 397781f1
......@@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter {
itensors.push_back(engine_->GetITensor(input_name));
}
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis == -1) {
axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions())
.nbDims -
1;
} else {
if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim
}
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
itensors.size());
if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim
}
layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -72,7 +69,8 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
if (engine_->use_oss() && engine_->with_ernie()) {
if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs.emplace_back(input);
......
......@@ -421,10 +421,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
if (with_dynamic_shape) {
if (axis < 0) return false;
} else {
if (axis <= 0) return false;
if (!with_dynamic_shape) {
if (axis == 0) return false;
}
auto concat_inputs = desc.Inputs();
if (concat_inputs.find("AxisTensor") != concat_inputs.end()) {
......
......@@ -113,32 +113,38 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x;
const int hidden = blockDim.x * gridDim.y;
const int batch = blockIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;
output[batch * hidden + threadIdx.x] =
slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
output[batch * hidden + local_idx] =
slice_input[cu_seqlens[batch] * hidden + local_idx];
}
int SpecialSlicePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1)
auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1)
assert(input_desc[0].type == nvinfer1::DataType::kHALF);
PADDLE_ENFORCE_EQ(
input_desc[0].type, nvinfer1::DataType::kHALF,
platform::errors::InvalidArgument("Type of input should be half."));
const int32_t hidden = input_dims.d[1];
const int num_blocks = out_dims.d[0]; // batch size
const int num_threads = hidden;
PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument(
"hidden should be multiple of 128."));
constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads);
const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]);
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
cu_seqlens, output);
return cudaGetLastError() != cudaSuccess;
}
......
......@@ -71,7 +71,7 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.zeros([1]).astype(np.int32)
for dims in [1, 2, 3, 4]:
for dims in [2, 3, 4]:
for num_input in [0, 1]:
for batch in [1, 2, 4]:
for axis in [-1, 0, 1, 2, 3]:
......@@ -277,12 +277,9 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape == True:
if attrs[0]['axis'] >= 0:
return 1, 4
else:
return 0, 5
return 1, 4
else:
if attrs[0]['axis'] > 0:
if attrs[0]['axis'] != 0:
return 1, 4
else:
return 0, 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册