Unverified Commit 65c17315 authored by Wangzheee, committed by GitHub

[Paddle Inference] optimize token prune for no varlen (#49094)

* optimize token prune for no varlen
Parent 4cdeab7b
......@@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto output_name = op_desc.Output("SlimmedX")[0];
auto out_inds_name = op_desc.Output("CLSInds")[0];
if (engine_->with_dynamic_shape()) {
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t reduce_dim = 0;
reduce_dim |= 1 << 1; // 00000000000000000000000000000010
reduce_dim |= 1 << 2; // 00000000000000000000000000000110
bool keep_dim = false;
nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim);
auto* Reduced = reduce_sum_layer->getOutput(0);
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
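The converter builds the per-token scores with TensorRT's Reduce layer, whose axis argument is a bitmask: bit i selects dimension i of the input tensor. A minimal sketch of that convention (the helper below is illustrative, not part of this patch):

```cpp
#include <cstdint>
#include <initializer_list>

// Illustrative helper (not in this patch): build a TensorRT reduce-axes
// bitmask, where bit i of the mask selects dimension i of the input tensor.
static uint32_t ReduceAxes(std::initializer_list<int> dims) {
  uint32_t mask = 0;
  for (int d : dims) mask |= 1u << d;
  return mask;
}

// ReduceAxes({1, 2}) == 0b110, i.e. the reduce_dim assembled above: summing
// the (-1, headsize, token_length, token_length) attention tensor over dims
// 1 and 2 with keep_dim == false leaves per-token scores of shape
// (-1, token_length).
```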
......@@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto* pos_id = engine_->GetITensor("pos_id");
auto* mask_id = engine_->GetITensor("mask_id");
- // reduce_sum: (-1,headsize,token_length,token_length) ->
- // (-1,token_length)
- uint32_t reduce_dim = 0;
- reduce_dim |= 1 << 1; // 00000000000000000000000000000010
- reduce_dim |= 1 << 2; // 00000000000000000000000000000110
- bool keep_dim = false;
- nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
- auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER(
-     engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim);
- // reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
- auto* Reduced = reduce_sum_layer->getOutput(0);
std::vector<nvinfer1::ITensor*> itensors = {
Reduced, X, Mask, NewMask, word_id, pos_id, mask_id};
- layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin);
layer = engine_->AddDynamicPlugin(
    itensors.data(), itensors.size(), plugin);  // number of inputs: 7
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
......@@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter {
layer->getOutput(4)->setName("mask_id_after_token_prune");
engine_->SetITensor("mask_id", layer->getOutput(4));
} else {
- std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
- layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
std::vector<nvinfer1::ITensor*> itensors = {Reduced, X, Mask, NewMask};
layer = engine_->AddDynamicPlugin(
    itensors.data(), itensors.size(), plugin);  // number of inputs: 4
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer->getOutput(1)->setName(out_inds_name.c_str());
engine_->SetITensor(out_inds_name, layer->getOutput(1));
}
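After this change both branches hand the plugin the reduced per-token scores (`Reduced`) instead of the raw 4-D attention map, and the input count is taken from `itensors.size()`. A sketch of the non-varlen input layout as the converter assembles it; the enum names are assumptions for illustration, not identifiers from the plugin source:

```cpp
// Illustrative input slots for the non-varlen path (names are assumptions).
enum FusedTokenPruneInputs {
  kScores = 0,   // Reduced: (-1, token_length) summed attention scores
  kX = 1,        // hidden states to be slimmed
  kMask = 2,     // original attention mask
  kNewMask = 3   // its trailing dims give the number of tokens to keep
};
```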
......
......@@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int nb_outputs) TRT_NOEXCEPT override {
max_batchs_ = in[1].max.d[0];
max_token_length_ = in[1].max.d[1];
int32_t padding_token_length;
if (max_token_length_ <= 64) {
padding_token_length = 64;
} else if (max_token_length_ <= 128) {
padding_token_length = 128;
} else if (max_token_length_ <= 256) {
padding_token_length = 256;
} else if (max_token_length_ <= 384) {
padding_token_length = 384;
} else if (max_token_length_ <= 512) {
padding_token_length = 512;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Token_prune'token_length(max) must <= 512"));
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&pruned_token_lengths_,
(max_batchs_ + 1) * sizeof(int32_t)));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
- &token_index_, max_batchs_ * max_token_length_ * sizeof(int32_t)));
&token_index_, max_batchs_ * padding_token_length * sizeof(int32_t)));
int32_t type_size = 4;
if (in[0].desc.type == nvinfer1::DataType::kHALF) {
type_size = 2;
} else {
type_size = 4;
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(
- &padding_scores_, max_batchs_ * max_token_length_ * sizeof(half)));
&padding_scores_, max_batchs_ * padding_token_length * type_size));
}
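The chain above rounds the maximum token length up to the nearest supported kernel size, so the scratch buffers allocated here cover the worst case at build time. The same rule as a compact sketch, assuming the bucket list stays {64, 128, 256, 384, 512}:

```cpp
#include <cstdint>

// Sketch of the padding rule used above; returns -1 where the plugin throws.
static int32_t PaddingTokenLength(int32_t max_token_length) {
  const int32_t buckets[] = {64, 128, 256, 384, 512};
  for (int32_t b : buckets) {
    if (max_token_length <= b) return b;
  }
  return -1;  // max token_length > 512 is rejected, as in configurePlugin
}
```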
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
......@@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int32_t* token_index_;
int32_t max_batchs_;
int32_t max_token_length_;
- half* padding_scores_;
void* padding_scores_;
};
class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
......
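`padding_scores_` becomes a `void*` because the buffer is now sized by the runtime element type (`type_size` is 2 bytes for kHALF, 4 bytes for kFLOAT) rather than always `sizeof(half)`. A hypothetical sketch of how the enqueue path would then recover a typed view; the function and parameter names are assumptions, not code from this patch:

```cpp
#include <NvInfer.h>
#include <cuda_fp16.h>

// Hypothetical fragment of an enqueue() implementation: recover a typed view
// of the type-erased scratch buffer from the runtime dtype of input 0.
void DispatchOnScoresType(const nvinfer1::PluginTensorDesc* input_desc,
                          void* padding_scores_) {
  if (input_desc[0].type == nvinfer1::DataType::kHALF) {
    half* scores = static_cast<half*>(padding_scores_);
    (void)scores;  // ... launch the FP16 pruning kernels with `scores` ...
  } else {
    float* scores = static_cast<float*>(padding_scores_);
    (void)scores;  // ... launch the FP32 pruning kernels with `scores` ...
  }
}
```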
......@@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
ctx_->PartialInitWithAllocator();
std::map<std::string, std::vector<int>> min_input_shape = {
{"attn", {4, 1, 4, 4}},
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"attn", {4, 1, 4, 4}},
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"attn", {4, 1, 4, 4}},
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
- AnalysisConfig::Precision::kHalf,
AnalysisConfig::Precision::kFloat32,
nullptr,
0,
min_input_shape,
......@@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
}
}
- void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
void PrepareInputOutput(const std::vector<std::vector<float>> inputs,
std::vector<std::vector<int>> output_shapes) {
LOG(INFO) << "PrepareInputOutput";
int num_inputs = inputs.size();
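The parameter type changes from `float16` to `float` because `TensorFromVector` deduces the created tensor's dtype from the vector's element type, and this test now declares its inputs as kFLOAT. A sketch of that behavior, reusing the fixture's `ctx_` and `inputs_` (the filler values are stand-ins):

```cpp
// TensorFromVector creates an FP32 GPU tensor from std::vector<float>,
// matching the kFLOAT bindings the test declares below.
std::vector<float> attn_v(16, 0.f);
paddle::framework::TensorFromVector(attn_v, *ctx_, &inputs_[0]);
```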
......@@ -423,7 +423,206 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
"attn", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2{-1, 4});
auto *x = engine_->DeclareInput(
"x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{-1, 4, 1});
auto *mask = engine_->DeclareInput(
"mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 4, 4});
auto *new_mask = engine_->DeclareInput(
"new_mask", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 1, 2, 2});
plugin::FusedTokenPrunePluginDynamic *plugin =
new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ false,
/*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
std::vector<nvinfer1::ITensor *> itensors = {attn, x, mask, new_mask};
auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
PADDLE_ENFORCE_NOT_NULL(layer,
platform::errors::InvalidArgument(
"TRT fused_token_prune layer building failed."));
std::vector<std::string> output_tensor_names{"out_slimmed_x", "out_cls_inds"};
for (size_t i = 0; i < 2; i++) {
layer->getOutput(i)->setName(output_tensor_names[i].c_str());
engine_->DeclareOutput(layer, i, output_tensor_names[i]);
}
engine_->FreezeNetwork();
ASSERT_EQ(engine_->engine()->getNbBindings(), 6);
LOG(INFO) << "create input";
std::vector<float> attn_v(16);
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
attn_v[j * 4 + k] = k;
}
}
std::vector<float> x_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
x_v[i * 4 + j] = 4 - j;
}
}
std::vector<float> mask_v(64);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
mask_v[i * 16 + j * 4 + k] = 1;
}
}
}
std::vector<float> new_mask_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 2; ++k) {
new_mask_v[i * 4 + j * 2 + k] = 1;
}
}
}
LOG(INFO) << "create output";
std::vector<int> out_slimmed_x_shape{4, 2, 1};
std::vector<int> out_cls_ins_shape{4, 2};
PrepareInputOutput({attn_v, x_v, mask_v, new_mask_v},
{out_slimmed_x_shape, out_cls_ins_shape});
auto *attn_gpu_data = inputs_[0].mutable_data<float>(ctx_->GetPlace());
auto *x_gpu_data = inputs_[1].mutable_data<float>(ctx_->GetPlace());
auto *mask_gpu_data = inputs_[2].mutable_data<float>(ctx_->GetPlace());
auto *new_mask_gpu_data = inputs_[3].mutable_data<float>(ctx_->GetPlace());
auto *slimmed_x_gpu_data = outputs_[0].mutable_data<float>(ctx_->GetPlace());
auto *cls_inds_gpu_data = outputs_[1].mutable_data<int32_t>(ctx_->GetPlace());
LOG(INFO) << "create buffers";
std::vector<void *> buffers(6);
buffers[0] = reinterpret_cast<void *>(attn_gpu_data);
buffers[1] = reinterpret_cast<void *>(x_gpu_data);
buffers[2] = reinterpret_cast<void *>(mask_gpu_data);
buffers[3] = reinterpret_cast<void *>(new_mask_gpu_data);
buffers[4] = reinterpret_cast<void *>(slimmed_x_gpu_data);
buffers[5] = reinterpret_cast<void *>(cls_inds_gpu_data);
LOG(INFO) << "Execute";
engine_->Execute(4, &buffers, ctx_->stream());
std::vector<float> slimmed_x_v(8);
std::vector<int32_t> cls_inds_v;
LOG(INFO) << "GetOutput";
GetOutput(slimmed_x_v, cls_inds_v);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ(slimmed_x_v[0], 2);
ASSERT_EQ(slimmed_x_v[1], 1);
ASSERT_EQ(slimmed_x_v[2], 2);
ASSERT_EQ(slimmed_x_v[3], 1);
ASSERT_EQ(slimmed_x_v[4], 2);
ASSERT_EQ(slimmed_x_v[5], 1);
ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1);
LOG(INFO) << "finish";
#endif
}
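What the assertions encode: every batch row has attention scores {0, 1, 2, 3} and features {4, 3, 2, 1}; keeping the top new_mask-width (2) tokens selects indices 2 and 3, so the slimmed row is {2, 1}. A minimal reference implementation of that expectation, assuming the plugin's semantics with keep_order == true (a sketch, not the plugin code):

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Reference for one batch row: keep the `keep` tokens with the largest
// attention score, restore their original order, and gather x there.
static std::vector<float> ReferencePrune(const std::vector<float>& attn,
                                         const std::vector<float>& x,
                                         int keep) {
  std::vector<int> idx(attn.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + keep, idx.end(),
                    [&](int a, int b) { return attn[a] > attn[b]; });
  std::sort(idx.begin(), idx.begin() + keep);  // keep_order == true
  std::vector<float> out;
  for (int i = 0; i < keep; ++i) out.push_back(x[idx[i]]);
  return out;
}

// attn = {0,1,2,3}, x = {4,3,2,1}, keep = 2 -> kept indices {2,3},
// output {2,1}, matching the slimmed_x_v assertions above.
```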
class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
ctx_->SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
ctx_->SetZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CUDAPlace(0))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
std::map<std::string, std::vector<int>> min_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"attn", {4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kHalf,
nullptr,
0,
min_input_shape,
max_input_shape,
optim_input_shape,
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
phi::DataType::FLOAT16,
NaiveLogger::Global());
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
std::vector<std::vector<int>> output_shapes) {
LOG(INFO) << "PrepareInputOutput";
int num_inputs = inputs.size();
int num_outputs = output_shapes.size();
inputs_.resize(num_inputs);
outputs_.resize(num_outputs);
for (int i = 0; i < num_inputs; ++i) {
paddle::framework::TensorFromVector(inputs[i], *ctx_, &inputs_[i]);
}
for (int i = 0; i < num_outputs; ++i) {
outputs_[i].Resize(phi::make_ddim(output_shapes[i]));
}
}
void GetOutput(std::vector<float> &slimmed_x, // NOLINT
std::vector<int32_t> &cls_inds) { // NOLINT
paddle::framework::TensorToVector(outputs_[0], *ctx_, &slimmed_x);
paddle::framework::TensorToVector(outputs_[1], *ctx_, &cls_inds);
}
protected:
std::vector<phi::DenseTensor> inputs_;
std::vector<phi::DenseTensor> outputs_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
};
TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims2{-1, 4});
auto *x = engine_->DeclareInput(
"x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1});
auto *mask = engine_->DeclareInput(
......@@ -431,7 +630,7 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
auto *new_mask = engine_->DeclareInput(
"new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
plugin::FusedTokenPrunePluginDynamic *plugin =
- new plugin::FusedTokenPrunePluginDynamic(true,
new plugin::FusedTokenPrunePluginDynamic(/*with_fp16*/ true,
/*keep_first_token*/ false,
/*keep_order*/ true,
/*flag_varseqlen*/ false);
......@@ -449,18 +648,16 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ(engine_->engine()->getNbBindings(), 6);
LOG(INFO) << "create input";
- std::vector<float16> attn_v(64);
- for (int i = 0; i < 4; ++i) {
std::vector<float16> attn_v(16);
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
- attn_v[i * 16 + j * 4 + k] = k;
- }
attn_v[j * 4 + k] = k;
}
}
std::vector<float16> x_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
- x_v[i * 4 + j] = 1;
x_v[i * 4 + j] = 4 - j;
}
}
std::vector<float16> mask_v(64);
......@@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
engine_->Execute(4, &buffers, ctx_->stream());
- std::vector<float> slimmed_x_v;
std::vector<float> slimmed_x_v(8);
std::vector<int32_t> cls_inds_v;
LOG(INFO) << "GetOutput";
GetOutput(slimmed_x_v, cls_inds_v);
ASSERT_EQ(cls_inds_v[0], 2);
ASSERT_EQ(cls_inds_v[1], 3);
ASSERT_EQ(cls_inds_v[2], 2);
ASSERT_EQ(cls_inds_v[3], 3);
ASSERT_EQ(cls_inds_v[4], 2);
ASSERT_EQ(cls_inds_v[5], 3);
ASSERT_EQ(cls_inds_v[6], 2);
ASSERT_EQ(cls_inds_v[7], 3);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ(slimmed_x_v[0], 2);
ASSERT_EQ(slimmed_x_v[1], 1);
ASSERT_EQ(slimmed_x_v[2], 2);
ASSERT_EQ(slimmed_x_v[3], 1);
ASSERT_EQ(slimmed_x_v[4], 2);
ASSERT_EQ(slimmed_x_v[5], 1);
ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1);
LOG(INFO) << "finish";
#endif
}
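The FP16 test now checks both outputs; its expectations follow from the same selection rule, applied per batch row. A usage sketch of the hypothetical `ReferencePrune` helper introduced after the FP32 test above:

```cpp
// Per-row usage of the ReferencePrune sketch from the FP32 test.
const std::vector<float> attn_row = {0.f, 1.f, 2.f, 3.f};
const std::vector<float> x_row = {4.f, 3.f, 2.f, 1.f};
std::vector<float> kept = ReferencePrune(attn_row, x_row, /*keep=*/2);
// kept == {2.f, 1.f}; the kept indices per row are {2, 3}, which is what
// the cls_inds_v assertions check for each of the 4 batch rows.
```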
......