未验证 提交 f080e8d5 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference]: fix concat slice (#39096)

* Paddle-Inference:fix_concat_slice

* Paddle-Inference:fix_concat_slice

* Paddle-Inference:fix_concat_slice

* Paddle-Inference:fix_concat_slice

* [Paddle-Inference]: fix concat slice

* [Paddle-Inference]: fix concat slice

* [Paddle-Inference]: fix concat slice
上级 d6d745d2
...@@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter { ...@@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter {
itensors.push_back(engine_->GetITensor(input_name)); itensors.push_back(engine_->GetITensor(input_name));
} }
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis == -1) {
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(), axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions())
itensors.size()); .nbDims -
1;
} else {
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim axis = axis - 1; // Remove batch dim
} }
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
itensors.size());
layer->setAxis(axis); layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode);
......
...@@ -72,7 +72,8 @@ class SliceOpConverter : public OpConverter { ...@@ -72,7 +72,8 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie()) { if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
...@@ -81,7 +82,7 @@ class SliceOpConverter : public OpConverter { ...@@ -81,7 +82,7 @@ class SliceOpConverter : public OpConverter {
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0), engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale); out_scale);
shuffler_slice->setName( shuffler_slice->setName(
("SpecialSlice_interleaved: Shuffle: (Output: " + output_name + ("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")") ")")
.c_str()); .c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0)); plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
......
...@@ -437,10 +437,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -437,10 +437,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
int axis = BOOST_GET_CONST(int, desc.GetAttr("axis")); int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
if (with_dynamic_shape) { if (!with_dynamic_shape) {
if (axis < 0) return false; if (axis == 0) return false;
} else {
if (axis <= 0) return false;
} }
auto concat_inputs = desc.Inputs(); auto concat_inputs = desc.Inputs();
if (concat_inputs.find("AxisTensor") != concat_inputs.end()) { if (concat_inputs.find("AxisTensor") != concat_inputs.end()) {
......
...@@ -113,32 +113,34 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( ...@@ -113,32 +113,34 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T> template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input, __global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) { const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x; const int hidden = blockDim.x * gridDim.y;
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;
output[batch * hidden + threadIdx.x] = output[batch * hidden + local_idx] =
slice_input[cu_seqlens[batch] * hidden + threadIdx.x]; slice_input[cu_seqlens[batch] * hidden + local_idx];
} }
int SpecialSlicePluginDynamic::enqueue( int SpecialSlicePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc, const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs, const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1) auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1) auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1)
assert(input_desc[0].type == nvinfer1::DataType::kHALF); assert(input_desc[0].type == nvinfer1::DataType::kHALF);
assert(hidden % 128 == 0);
const int32_t hidden = input_dims.d[1]; const int32_t hidden = input_dims.d[1];
const int num_blocks = out_dims.d[0]; // batch size constexpr int num_threads = 128;
const int num_threads = hidden; const dim3 blocks(out_dims.d[0], hidden / num_threads);
const half* slice_input = static_cast<const half*>(inputs[0]); const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]); const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]); half* output = static_cast<half*>(outputs[0]);
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>( SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
slice_input, cu_seqlens, output); cu_seqlens, output);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
......
...@@ -71,7 +71,7 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest): ...@@ -71,7 +71,7 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]): def generate_weight1(attrs: List[Dict[str, Any]]):
return np.zeros([1]).astype(np.int32) return np.zeros([1]).astype(np.int32)
for dims in [1, 2, 3, 4]: for dims in [2, 3, 4]:
for num_input in [0, 1]: for num_input in [0, 1]:
for batch in [1, 2, 4]: for batch in [1, 2, 4]:
for axis in [-1, 0, 1, 2, 3]: for axis in [-1, 0, 1, 2, 3]:
...@@ -277,12 +277,9 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest): ...@@ -277,12 +277,9 @@ class TrtConvertConcatTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape == True: if dynamic_shape == True:
if attrs[0]['axis'] >= 0:
return 1, 4 return 1, 4
else: else:
return 0, 5 if attrs[0]['axis'] != 0:
else:
if attrs[0]['axis'] > 0:
return 1, 4 return 1, 4
else: else:
return 0, 5 return 0, 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册