Unverified commit 993f13d3, authored by Wangzheee, committed by GitHub

general transformer(interleaved) inference support (#43600)

Parent c5097af7
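This change has two parts. First, the remove_padding_recover_padding IR pass now copies the out_threshold attribute (the int8 quantization scale recorded by Paddle's quantization passes) from the matched op onto the remove_padding and recover_padding ops it inserts. Second, the two TensorRT op converters gain an interleaved-data-format branch: when the engine runs with interleaved varlen tensors, a Shuffle layer with permutation {2, 1, 0, 3} converts between the linear and interleaved layouts around the varlen plugins, and SetTensorDynamicRange registers the int8 range taken from out_threshold.

The snippet below is an illustrative sketch, not part of the patch. It shows what the {2, 1, 0, 3} permutation applied by setSecondTranspose() does to a 4-D shape (in TensorRT, output dimension i is taken from input dimension perm[i]); the shape [tokens, hidden, 1, 1] is a hypothetical varlen activation.

// Illustrative sketch (not from the patch): the effect of the {2, 1, 0, 3}
// permutation used by both converters. Dimensions 0 and 2 swap; 1 and 3 stay.
#include <array>
#include <cstdio>

int main() {
  const std::array<int, 4> perm = {2, 1, 0, 3};
  const std::array<int, 4> in_shape = {256, 768, 1, 1};  // hypothetical [tokens, hidden, 1, 1]
  std::array<int, 4> out_shape{};
  for (int i = 0; i < 4; ++i) out_shape[i] = in_shape[perm[i]];
  // Prints "1 x 768 x 256 x 1", the interleaved ordering of the same data.
  std::printf("%d x %d x %d x %d\n", out_shape[0], out_shape[1], out_shape[2],
              out_shape[3]);
  return 0;
}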
@@ -130,6 +130,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// output
remove_padding.SetOutput("Out", {remove_padding_out_name});
// Propagate out_threshold (int8 quantization scale) so the TensorRT
// converter can set this op output's dynamic range.
if (op_node->Op()->HasAttr("out_threshold")) {
remove_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_threshold"));
}
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
auto remove_padding_out_node = graph->CreateVarNode(remove_padding_out);
@@ -184,6 +190,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// output
recover_padding.SetOutput("Out", {out_node->Name()});
// Propagate out_threshold (int8 quantization scale) so the TensorRT
// converter can set this op output's dynamic range.
if (op_node->Op()->HasAttr("out_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_threshold"));
}
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
auto recover_padding_input_node =
graph->CreateVarNode(recover_padding_input);
......
@@ -36,34 +36,45 @@ class RecoverPadding : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "Recover padding of transformer'output: VarSeqlen -> Padding.";
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"recover_padding_op: If you want to use transformer, must "
"be with dynamic shape"));
}
framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("Input").front();
VLOG(3) << "input_name: " << input_name;
auto input = engine_->GetITensor(input_name);
auto output_name = op_desc.Output("Out").front();
std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) {
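// Interleaved path: the upstream varlen tensor arrives in the interleaved
// layout, so transpose it back to the linear layout RecoverPaddingPlugin
// expects, and register the transposed tensor's int8 dynamic range from
// the out_threshold attribute.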
VLOG(3) << "with_interleaved data format: Recover padding of "
"transformer'output: VarSeqlen -> Padding.";
if (!op_desc.HasAttr("out_threshold")) {
PADDLE_THROW(platform::errors::Fatal(
"recover_padding_op: with_interleaved requires int8 "
"(the out_threshold attribute is missing)."));
}
auto* transpose = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
transpose->setSecondTranspose({2, 1, 0, 3});
auto* transpose_output = transpose->getOutput(0);
transpose->setName(
("recover_padding(with_interleaved): transpose(Output: " +
output_name + ")")
.c_str());
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(transpose_output, out_scale);
plugin_inputs.push_back(transpose_output);
} else {
VLOG(3) << "normal data format: Recover padding of transformer'output: "
"VarSeqlen -> Padding.";
plugin_inputs.push_back(input);
}
plugin_inputs.push_back(engine_->GetITensor("pos_id"));
plugin_inputs.push_back(engine_->GetITensor("mask_id"));
size_t input_num = plugin_inputs.size();
plugin::RecoverPaddingPlugin* plugin = new plugin::RecoverPaddingPlugin();
nvinfer1::ILayer* layer =
engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);
RreplenishLayerAndOutput(layer, "recover_padding", {output_name},
test_mode);
}
......
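Note the symmetry between the two converters: recover_padding transposes its input before the plugin (interleaved back to linear), while remove_padding below transposes the plugin's output (linear to interleaved for the downstream varlen ops). The same Shuffle permutation works in both directions because {2, 1, 0, 3} is its own inverse, which the small sketch below (illustrative only, not part of the patch) verifies.

// Illustrative check (not from the patch): {2, 1, 0, 3} is an involution,
// so applying it twice restores the original dimension order.
#include <array>
#include <cassert>

int main() {
  const std::array<int, 4> perm = {2, 1, 0, 3};
  for (int i = 0; i < 4; ++i) assert(perm[perm[i]] == i);
  return 0;
}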
@@ -36,7 +36,6 @@ class RemovePadding : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "Remove padding of transformer'input: Padding -> VarSeqlen";
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"remove_padding_op: If you want to use transformer, must "
@@ -45,20 +44,39 @@ class RemovePadding : public OpConverter {
framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("Input").front();
auto output_name = op_desc.Output("Out").front();
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(engine_->GetITensor(input_name));
plugin_inputs.push_back(engine_->GetITensor("pos_id"));
plugin_inputs.push_back(engine_->GetITensor("word_id"));
size_t input_num = plugin_inputs.size();
plugin::RemovePaddingPlugin* plugin = new plugin::RemovePaddingPlugin();
nvinfer1::ILayer* layer =
engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);
layer->setName(("remove_padding: (Output: " + output_name + ")").c_str());
if (engine_->with_interleaved()) {
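// Interleaved path: register the plugin output's int8 dynamic range from
// the out_threshold attribute, then transpose the output from the linear
// layout into the interleaved layout expected by the downstream varlen ops.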
VLOG(3) << "with_interleaved data format: Remove padding of "
"transformer'input: Padding -> VarSeqlen.";
if (!op_desc.HasAttr("out_threshold")) {
PADDLE_THROW(platform::errors::Fatal(
"remove_padding_op: with_interleaved requires int8 "
"(the out_threshold attribute is missing)."));
}
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
auto* transpose =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
transpose->setSecondTranspose({2, 1, 0, 3});
transpose->setName(
("remove_padding (with_interleaved): transpose(Output: " +
output_name + ")")
.c_str());
engine_->SetITensor(output_name, transpose->getOutput(0));
} else {
VLOG(3) << "normal data format: Remove padding of transformer'input: "
"Padding -> VarSeqlen.";
engine_->SetITensor(output_name, layer->getOutput(0));
}
}
};
......