Unverified commit 12486712, authored by RichardWooSJTU, committed by GitHub

Add int8 support in fused_multi_transformer_pass and fuse_multi_transformer_layer_pass (#48209)

* delete unnecessary shape and slice op
Co-authored-by: Your Name <you@example.com>
Parent 9ff99e9e
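Before the diff, a note on the convention the pieces below share. The delete_*_dequant_linear passes record a per-tensor activation scale as an "Input_scale_<var>" attribute on the consuming op and a per-tensor "weight_scale" attribute; the fusion passes then store the reciprocal of each activation scale in the qkv_in_scale / out_linear_in_scale / ffn1_in_scale / ffn2_in_scale attributes of fused_multi_transformer_int8 and wire per-weight *OutScale persistable inputs. The standalone sketch below illustrates the symmetric per-tensor int8 math assumed behind that convention; it is not Paddle code, and the helper names and the reading of the scale as the dequantization step size are assumptions. Storing the reciprocal up front trades a divide for a multiply in the int8 kernel, which is why the passes invert the scales before setting the attributes.

// Standalone sketch of symmetric per-tensor int8 quantization, assuming the
// recorded scale is the dequantization step size (x ~= q * scale). Helper
// names are illustrative, not part of Paddle.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize with the inverse scale, mirroring the "*_in_scale" attributes that
// the fusion passes set after the "Inverse input scale" step.
static int8_t Quantize(float x, float inv_scale) {
  float q = std::round(x * inv_scale);
  q = std::max(-127.0f, std::min(127.0f, q));
  return static_cast<int8_t>(q);
}

static float Dequantize(int8_t q, float scale) { return q * scale; }

int main() {
  // Pretend this value came from an "Input_scale_<var>" attribute collected by
  // delete_quant_dequant_linear_op_pass (assumed step size, for illustration).
  const float input_scale = 0.05f;
  const float in_scale_attr = 1.0f / input_scale;  // what the fused op stores

  const std::vector<float> activations = {0.12f, -0.7f, 1.9f, 0.0f};
  for (float x : activations) {
    const int8_t q = Quantize(x, in_scale_attr);
    std::printf("x=%7.3f  q=%4d  dequant=%7.3f\n",
                x, q, Dequantize(q, input_scale));
  }
  return 0;
}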
@@ -96,6 +96,8 @@ pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_weight_dequant_linear_op_encoder_pass inference)
pass_library(delete_weight_dequant_linear_op_decoder_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
......
@@ -121,14 +121,27 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
float input_scale;
if (input_scale_tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
const float* input_scale_data = input_scale_tensor.data<float>();
input_scale = input_scale_data[0];
} else if (input_scale_tensor.dtype() ==
paddle::experimental::DataType::FLOAT16) {
const phi::dtype::float16* input_scale_data =
input_scale_tensor.data<phi::dtype::float16>();
input_scale = static_cast<float>(input_scale_data[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented("%d is not supported.",
input_scale_tensor.dtype()));
}
int nums_any_ops = dequantize_linear_op_out->outputs.size();
for (int i = 0; i < nums_any_ops; ++i) {
auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
// link x to any_op2
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpDecoderPass::
DeleteWeightDequantLinearOpDecoderPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete dequantize_linear_op, then dequantize weight
void DeleteWeightDequantLinearOpDecoderPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_dequant_linear_op_decoder_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightDequantLinearOpDecoderPass "
"should not be null."));
// Create pattern
patterns::DeleteWeightDequantLinearOpDecoderPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
bool is_int8 = false;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8 = true;
std::unordered_set<const Node*> nodes2rm = {};
auto* any_op2_desc = any_op2->Op();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<phi::DenseTensor>();
auto weight_scale_nums = weight_scale_tensor->numel();
if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT32) {
float* weight_scale_data = weight_scale_tensor->data<float>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i]);
}
} else if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT16) {
phi::dtype::float16* weight_scale_data =
weight_scale_tensor->data<phi::dtype::float16>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"%d is not supported.", weight_scale_tensor->dtype()));
}
int quant_axis = PADDLE_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// Add attr to anyop 2
any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Decoder Pass is not supported for "
"per-channel quantization"));
}
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
if (is_int8) {
auto& enable_int8 = graph->Get<bool>("enable_int8");
enable_int8 = true;
}
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_decoder_pass,
paddle::framework::ir::DeleteWeightDequantLinearOpDecoderPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightDequantLinearOpDecoderPass : public FusePassBase {
public:
DeleteWeightDequantLinearOpDecoderPass();
virtual ~DeleteWeightDequantLinearOpDecoderPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpEncoderPass::
DeleteWeightDequantLinearOpEncoderPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete dequantize_linear_op, then dequantize weight
void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_dequant_linear_op_encoder_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightDequantLinearOpEncoderPass "
"should not be null."));
// Create pattern
patterns::DeleteWeightDequantLinearOpEncoderPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
bool is_int8 = false;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8 = true;
std::unordered_set<const Node*> nodes2rm = {};
auto* any_op2_desc = any_op2->Op();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<phi::DenseTensor>();
auto weight_scale_nums = weight_scale_tensor->numel();
if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT32) {
float* weight_scale_data = weight_scale_tensor->data<float>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i]);
}
} else if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT16) {
phi::dtype::float16* weight_scale_data =
weight_scale_tensor->data<phi::dtype::float16>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"%d is not supported.", weight_scale_tensor->dtype()));
}
int quant_axis = PADDLE_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// Add attr to anyop 2
any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Encoder Pass is not supported for "
"per-channel quantization"));
}
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
graph->Set("enable_int8", new bool(is_int8));
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_encoder_pass,
paddle::framework::ir::DeleteWeightDequantLinearOpEncoderPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightDequantLinearOpEncoderPass : public FusePassBase {
public:
DeleteWeightDequantLinearOpEncoderPass();
virtual ~DeleteWeightDequantLinearOpEncoderPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -118,9 +118,15 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool enable_int8 = false;
if (graph->Has("enable_int8")) {
enable_int8 = graph->Get<bool>("enable_int8");
}
if (!enable_int8) {
VLOG(4)
<< "fuse_multi_layer_transformer_pass will match float transformer op "
"cause enable_int8 is not been set or set to false";
}
int num_fuse_op = 0;
bool is_decoder = false;
@@ -209,7 +215,13 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
"OutLinearW",
"QKVBias",
"QKVW"};
if (enable_int8) {
std::vector<std::string> inputs_names_int8_supp = {
"FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"};
inputs_names.insert(inputs_names.end(),
inputs_names_int8_supp.begin(),
inputs_names_int8_supp.end());
}
for (const auto& input_name : inputs_names) {
MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
}
@@ -227,6 +239,17 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
}
fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);
if (enable_int8) {
// Merge inputs scale
std::vector<std::string> attr_names = {"qkv_in_scale",
"out_linear_in_scale",
"ffn1_in_scale",
"ffn2_in_scale"};
for (const auto& name : attr_names) {
MergeAttrs<float>(fuse_op_descs, name);
}
}
////////////////
//// ReLink ////
////////////////
......
@@ -98,6 +98,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers));
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
if (pass.get() == nullptr)
......
@@ -1075,12 +1075,27 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
} // namespace patterns
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
auto var_desc = VarDesc(name);
var_desc.SetDataType(framework::proto::VarType::FP32);
var_desc.SetPersistable(true);
auto node = graph->CreateVarNode(&var_desc);
return node;
}
int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerDecoderPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerDecoderPass with fp";
}
// Create pattern.
patterns::FusedMultiTransformerDecoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
@@ -1093,6 +1108,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
@@ -1103,6 +1119,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* ffn_layer_norm,
@@ -1110,11 +1127,17 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
@@ -1126,7 +1149,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
@@ -1181,8 +1206,66 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
if (enable_int8) {
// Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Inverse input scale
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
@@ -1456,6 +1539,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0,
matmul0_w,
matmul1_w,
matmul2_w,
@@ -1466,6 +1550,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
transpose2_2_out,
eltadd_qk_b,
reshape2_0,
matmul_linear,
matmul_linear_w,
eltadd_linear_b,
ffn_layer_norm,
@@ -1473,7 +1558,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
@@ -1732,6 +1819,13 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerDecoderFuseQKVPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerDecoderFuseQKVPass with fp";
}
// Create pattern.
patterns::FusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
@@ -1744,10 +1838,12 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* ffn_layer_norm,
@@ -1755,11 +1851,17 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
@@ -1771,7 +1873,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
@@ -1826,8 +1930,65 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
fused_multi_transformer_op_desc.SetAttr("is_test", true);
if (enable_int8) {
// Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Inverse input scale
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
@@ -2088,10 +2249,12 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
reshape2_0,
matmul_linear,
matmul_linear_w,
eltadd_linear_b,
ffn_layer_norm,
@@ -2099,7 +2262,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
@@ -2349,6 +2514,13 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "MultiDevicesFusedMultiTransformerDecoderFuseQKVPass with int8";
} else {
VLOG(3) << "MultiDevicesFusedMultiTransformerDecoderFuseQKVPass with fp";
}
// Create pattern.
patterns::MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
@@ -2362,10 +2534,12 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* c_identity,
Node* matmul0,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* ffn_layer_norm,
@@ -2373,11 +2547,16 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_c_identity,
Node* ffn_matmul0,
Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_output) {
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
@@ -2389,7 +2568,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
@@ -2449,8 +2630,71 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id"));
if (enable_int8) {
std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float,
c_identity_op->GetAttr("Input_scale_" + matmul_input_scale_suffix));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
auto* ffn_c_identity_op = ffn_c_identity->Op();
std::string ffn_input_scale_suffix = ffn_c_identity_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float,
ffn_c_identity_op->GetAttr("Input_scale_" + ffn_input_scale_suffix));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
// Inverse input scale
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
@@ -2737,10 +2981,12 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
layer_norm_mean,
layer_norm_variance,
c_identity,
matmul0,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
reshape2_0,
matmul_linear,
matmul_linear_w,
eltadd_linear_b,
ffn_layer_norm,
@@ -2748,7 +2994,10 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_c_identity,
ffn_matmul0,
ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
......
@@ -193,6 +193,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass");
@@ -344,6 +345,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_decoder_fuse_qkv_pass");
@@ -503,6 +505,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass");
......
@@ -1025,21 +1025,14 @@ template <typename T>
inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
phi::DenseTensor* wk_tensor,
phi::DenseTensor* wv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
phi::DenseTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
@@ -1065,6 +1058,20 @@ inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
}
template <typename T>
inline void QKVBiasProcess(phi::DenseTensor* bq_tensor,
phi::DenseTensor* bk_tensor,
phi::DenseTensor* bv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});
phi::DenseTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
@@ -1085,13 +1092,57 @@ inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor,
sizeof(T) * bq_tensor->numel());
}
inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor,
phi::DenseTensor* wk_tensor,
phi::DenseTensor* wv_tensor,
phi::DenseTensor* bq_tensor,
phi::DenseTensor* bk_tensor,
phi::DenseTensor* bv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
switch (wq_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVWeightsProcess<platform::float16>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVWeightsProcess<float>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::INT8:
QKVWeightsProcess<int8_t>(
wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."));
break;
}
switch (bq_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVBiasProcess<platform::float16>(
bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVBiasProcess<float>(
bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."));
break;
}
}
template <typename T>
inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* qkv_w_data = qkv_w_tensor->data<T>();
auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
phi::DenseTensor tmp_transpose_w_tensor;
@@ -1120,8 +1171,14 @@ inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
memcpy(new_transpose_w_data,
tmp_transpose_w_data,
sizeof(T) * qkv_w_tensor->numel());
}
template <typename T>
inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* qkv_b_data = qkv_b_tensor->data<T>();
auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head});
phi::DenseTensor tmp_transpose_b_tensor;
@@ -1148,11 +1205,86 @@ inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
sizeof(T) * qkv_b_tensor->numel());
}
inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor,
phi::DenseTensor* qkv_b_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
switch (qkv_w_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVWeightsProcessFuseQKV<platform::float16>(
qkv_w_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::INT8:
QKVWeightsProcessFuseQKV<int8_t>(
qkv_w_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32/fp16/int8."));
break;
}
switch (qkv_b_tensor->dtype()) {
case paddle::experimental::DataType::FLOAT16:
QKVBiasProcessFuseQKV<platform::float16>(
qkv_b_tensor, num_head, dim_head, dim_embed);
break;
case paddle::experimental::DataType::FLOAT32:
QKVBiasProcessFuseQKV<float>(qkv_b_tensor, num_head, dim_head, dim_embed);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported bias dtype. "
"we now only support fp32/fp16."));
break;
}
}
// Only used for fused_multi_transformer_int8
inline void TransposeWeights(phi::DenseTensor* weight_tensor) {
int m = weight_tensor->dims()[0];
int n = weight_tensor->dims()[1];
phi::DenseTensor tmp_weight_tensor;
auto tmp_weight_data =
tmp_weight_tensor.mutable_data<int8_t>({n, m}, platform::CPUPlace());
auto weight_data = weight_tensor->data<int8_t>();
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
int in_idx = i * n + j;
int out_idx = j * m + i;
tmp_weight_data[out_idx] = weight_data[in_idx];
}
}
weight_tensor->Resize({n, m});
auto new_weight_data =
weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n);
}
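// A minimal standalone sketch of the same row-major transpose, using
// std::vector instead of phi::DenseTensor. The helper name below is
// illustrative only and is not part of this pass.
#include <cstdint>
#include <vector>

std::vector<int8_t> TransposeRowMajorInt8(const std::vector<int8_t>& src,
                                          int m, int n) {
  std::vector<int8_t> dst(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      dst[j * m + i] = src[i * n + j];  // in_idx = i*n+j -> out_idx = j*m+i
    }
  }
  return dst;
}
// Example: the 2 x 3 matrix {1, 2, 3, 4, 5, 6} becomes the 3 x 2 matrix
// {1, 4, 2, 5, 3, 6}. The O(m * n) CPU cost is paid once per weight at
// graph-build time, so no transpose is needed at inference.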
inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) {
auto var_desc = VarDesc(name);
var_desc.SetDataType(framework::proto::VarType::FP32);
var_desc.SetPersistable(true);
auto node = graph->CreateVarNode(&var_desc);
return node;
}
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
const std::string& name_scope, const std::string& name_scope,
Scope* scope) const { Scope* scope) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerEncoderPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerEncoderPass with fp";
}
// Create pattern. // Create pattern.
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
...@@ -1166,6 +1298,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1166,6 +1298,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* layer_norm_bias, Node* layer_norm_bias,
Node* layer_norm_mean, Node* layer_norm_mean,
Node* layer_norm_variance, Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w, Node* matmul0_w,
Node* matmul1_w, Node* matmul1_w,
Node* matmul2_w, Node* matmul2_w,
...@@ -1176,6 +1309,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1176,6 +1309,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* transpose2_2_out, Node* transpose2_2_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* while0, Node* while0,
...@@ -1184,7 +1318,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1184,7 +1318,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* ffn_layer_norm_bias, Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean, Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance, Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w, Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
...@@ -1196,7 +1332,14 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1196,7 +1332,14 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
int dim_head = int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")) PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3); .at(3);
int dim_embed = num_head * dim_head; auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0];
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name // Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes: // This calculation assumes:
...@@ -1221,30 +1364,27 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1221,30 +1364,27 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
auto* bv_tensor = auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
if (wq_tensor->dtype() == phi::DataType::FLOAT32) { QKVWeightsBiasProcess(wq_tensor,
QKVWeightsProcess<float>(wq_tensor, wk_tensor,
wk_tensor, wv_tensor,
wv_tensor, bq_tensor,
bq_tensor, bk_tensor,
bk_tensor, bv_tensor,
bv_tensor, num_head,
num_head, dim_head,
dim_head, dim_embed);
dim_embed);
} else if (wq_tensor->dtype() == phi::DataType::FLOAT16) { if (enable_int8) {
QKVWeightsProcess<platform::float16>(wq_tensor, auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
wk_tensor, ->GetMutable<phi::DenseTensor>();
wv_tensor, auto* ffn0_w_tensor =
bq_tensor, scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
bk_tensor, auto* ffn1_w_tensor =
bv_tensor, scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
num_head,
dim_head, TransposeWeights(out_linear_w_tensor);
dim_embed); TransposeWeights(ffn0_w_tensor);
} else { TransposeWeights(ffn1_w_tensor);
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
} }
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes. // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
...@@ -1261,7 +1401,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1261,7 +1401,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
// create fused_multi_transformer // create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting // 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
...@@ -1281,7 +1423,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1281,7 +1423,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024 // FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType( cache_kv_desc.SetDataType(
framework::TransToProtoVarType(wq_tensor->dtype())); framework::TransToProtoVarType(bq_tensor->dtype()));
cache_kv_desc.SetPersistable(false); cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...@@ -1296,7 +1438,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1296,7 +1438,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr( fill_const_op_desc.SetAttr(
"dtype", "dtype",
static_cast<int>(framework::TransToProtoVarType(wq_tensor->dtype()))); static_cast<int>(framework::TransToProtoVarType(bq_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -1333,8 +1475,123 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1333,8 +1475,123 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
// Quantization attribute/Input
if (enable_int8) {
// Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
      // Calculate output scales and set them
auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale =
PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
auto ffn0_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
auto ffn1_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));
auto qkv_out_scales = std::vector<float>(
3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
auto out_out_scales = std::vector<float>(
dim_embed,
(out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
auto ffn0_out_scales = std::vector<float>(
4 * dim_embed,
(ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
auto ffn1_out_scales = std::vector<float>(
dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));
      // Invert the input scales before storing them as attributes
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
auto out_out_scale_var =
scope->Var(matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_var =
scope->Var(ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data =
qkv_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
memcpy(qkv_out_scale_data,
qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data =
out_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(out_out_scale_data,
out_out_scales.data(),
out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
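// A small worked example of the scale bookkeeping above, with made-up
// calibration values (illustrative only, not taken from any model):
//   weight_scale = 1.27f, in_scale = 6.35f
//   dequant out scale = (1.27f / 127.0f) * (6.35f / 127.0f) = 5.0e-4f
//   attribute stored on the fused op = 1.0f / 6.35f ~= 0.1575f
// A minimal sketch of the same arithmetic (free functions, not part of the
// pass):
inline float DequantOutScale(float weight_scale, float in_scale) {
  // Combined scale used to turn the int32 GEMM result back into floats;
  // replicated per output channel (3 * dim_embed values for QKV).
  return (weight_scale / 127.0f) * (in_scale / 127.0f);
}
inline float QuantInScaleAttr(float in_scale) {
  // The pass stores the reciprocal, presumably so the kernel can multiply
  // instead of divide when quantizing activations.
  return 1.0f / in_scale;
}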
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
...@@ -1622,6 +1879,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1622,6 +1879,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
layer_norm_bias, layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
matmul0,
matmul0_w, matmul0_w,
matmul1_w, matmul1_w,
matmul2_w, matmul2_w,
...@@ -1632,6 +1890,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1632,6 +1890,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
transpose2_2_out, transpose2_2_out,
eltadd_qk_b, eltadd_qk_b,
reshape2_0, reshape2_0,
matmul_linear,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
while0, while0,
...@@ -1640,7 +1899,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1640,7 +1899,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_layer_norm_bias, ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w, ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
...@@ -1892,6 +2153,12 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1892,6 +2153,12 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const { Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8";
} else {
VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp";
}
// Create pattern. // Create pattern.
patterns::FusedMultiTransformerEncoderFuseQKVPattern patterns::FusedMultiTransformerEncoderFuseQKVPattern
...@@ -1905,12 +2172,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1905,12 +2172,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* layer_norm_bias, Node* layer_norm_bias,
Node* layer_norm_mean, Node* layer_norm_mean,
Node* layer_norm_variance, Node* layer_norm_variance,
Node* matmul0,
Node* matmul0_w, Node* matmul0_w,
Node* eltadd0_b, Node* eltadd0_b,
Node* split0_k_out, Node* split0_k_out,
Node* split0_v_out, Node* split0_v_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* while0, Node* while0,
...@@ -1919,7 +2188,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1919,7 +2188,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_layer_norm_bias, Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean, Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance, Node* ffn_layer_norm_variance,
Node* ffn_matmul0,
Node* ffn_matmul0_w, Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
...@@ -1932,7 +2203,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1932,7 +2203,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")) PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) / .at(3) /
3; // 3 for qkv 3; // 3 for qkv
int dim_embed = num_head * dim_head; auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0];
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name // Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes: // This calculation assumes:
...@@ -1948,21 +2226,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1948,21 +2226,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto* qkv_b_tensor = auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) { QKVWeightsBiasProcessFuseQKV(
QKVWeightsProcessFuseQKV<float>( qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) { if (enable_int8) {
QKVWeightsProcessFuseQKV<platform::float16>( auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); ->GetMutable<phi::DenseTensor>();
} else { auto* ffn0_w_tensor =
PADDLE_THROW(platform::errors::Unavailable( scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
"fused_multi_transformer not supported weight dtype. " auto* ffn1_w_tensor =
"we now only support fp32 and fp16.")); scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
TransposeWeights(out_linear_w_tensor);
TransposeWeights(ffn0_w_tensor);
TransposeWeights(ffn1_w_tensor);
} }
// create fused_multi_transformer // create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting // 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
...@@ -1982,7 +2266,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1982,7 +2266,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024 // FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType( cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype())); framework::TransToProtoVarType(qkv_b_tensor->dtype()));
cache_kv_desc.SetPersistable(false); cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...@@ -1997,7 +2281,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -1997,7 +2281,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", fill_const_op_desc.SetAttr("dtype",
static_cast<int>(framework::TransToProtoVarType( static_cast<int>(framework::TransToProtoVarType(
qkv_w_tensor->dtype()))); qkv_b_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -2035,8 +2319,125 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2035,8 +2319,125 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
// Quantization attribute/Input
if (enable_int8) {
// Set input scale
std::string qkv_input_name = matmul0_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
      // Calculate output scales and set them
      // TODO(wufeisheng): Currently only a layer-wise weight scale is matched;
      // channel-wise weight scales should also be supported (see the sketch
      // after this block).
auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale =
PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
auto ffn0_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
auto ffn1_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));
auto qkv_out_scales = std::vector<float>(
3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
auto out_out_scales = std::vector<float>(
dim_embed,
(out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
auto ffn0_out_scales = std::vector<float>(
4 * dim_embed,
(ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
auto ffn1_out_scales = std::vector<float>(
dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));
      // Invert the input scales before storing them as attributes
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
auto out_out_scale_var =
scope->Var(matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_var =
scope->Var(ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data =
qkv_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
memcpy(qkv_out_scale_data,
qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data =
out_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(out_out_scale_data,
out_out_scales.data(),
out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
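// Sketch of the layer-wise vs. channel-wise out-scale construction referenced
// by the TODO above. The pass currently broadcasts one scalar per GEMM;
// channel-wise support would read a per-output-channel weight scale instead.
// Function names are illustrative only and not part of the pass.
#include <vector>

std::vector<float> LayerWiseOutScales(float weight_scale,
                                      float in_scale,
                                      int out_features) {
  // One (w_scale / 127) * (x_scale / 127) value replicated per channel.
  return std::vector<float>(out_features,
                            (weight_scale / 127.0f) * (in_scale / 127.0f));
}

std::vector<float> ChannelWiseOutScales(const std::vector<float>& weight_scales,
                                        float in_scale) {
  // Same formula, but each output channel keeps its own weight scale.
  std::vector<float> out(weight_scales.size());
  for (size_t i = 0; i < weight_scales.size(); ++i) {
    out[i] = (weight_scales[i] / 127.0f) * (in_scale / 127.0f);
  }
  return out;
}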
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
...@@ -2290,12 +2691,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2290,12 +2691,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
layer_norm_bias, layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
matmul0,
matmul0_w, matmul0_w,
eltadd0_b, eltadd0_b,
split0_k_out, split0_k_out,
split0_v_out, split0_v_out,
eltadd_qk_b, eltadd_qk_b,
reshape2_0, reshape2_0,
matmul_linear,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
while0, while0,
...@@ -2304,7 +2707,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2304,7 +2707,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_layer_norm_bias, ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_matmul0,
ffn_matmul0_w, ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
...@@ -2546,6 +2951,12 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2546,6 +2951,12 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const { Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
bool enable_int8 = graph->Get<bool>("enable_int8");
if (enable_int8) {
VLOG(3) << "MultiDevicesFusedMultiTransformerEncoderFuseQKVPass with int8";
} else {
VLOG(3) << "MultiDevicesFusedMultiTransformerEncoderFuseQKVPass with fp";
}
// Create pattern. // Create pattern.
patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
...@@ -2560,12 +2971,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2560,12 +2971,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* layer_norm_mean, Node* layer_norm_mean,
Node* layer_norm_variance, Node* layer_norm_variance,
Node* c_identity, Node* c_identity,
Node* matmul0,
Node* matmul0_w, Node* matmul0_w,
Node* eltadd0_b, Node* eltadd0_b,
Node* split0_k_out, Node* split0_k_out,
Node* split0_v_out, Node* split0_v_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* while0, Node* while0,
...@@ -2574,7 +2987,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2574,7 +2987,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_layer_norm_bias, Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean, Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance, Node* ffn_layer_norm_variance,
Node* ffn_c_identity,
Node* ffn_matmul0,
Node* ffn_matmul0_w, Node* ffn_matmul0_w,
Node* ffn_matmul1,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
...@@ -2588,6 +3004,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2588,6 +3004,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
.at(3) / .at(3) /
3; // 3 for qkv 3; // 3 for qkv
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_0_op = ffn_matmul0->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
// Calc index of transformer layer by LayerNorm Scale name // Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes: // This calculation assumes:
// 1. no LayerNorm before all transformer layer // 1. no LayerNorm before all transformer layer
...@@ -2602,23 +3023,31 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2602,23 +3023,31 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
auto* qkv_b_tensor = auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = qkv_w_tensor->dims()[0]; auto* layer_norm_bias_tensor =
scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
int dim_embed = layer_norm_bias_tensor->dims()[0];
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) { QKVWeightsBiasProcessFuseQKV(
QKVWeightsProcessFuseQKV<float>( qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) { if (enable_int8) {
QKVWeightsProcessFuseQKV<platform::float16>( auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name())
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); ->GetMutable<phi::DenseTensor>();
} else { auto* ffn0_w_tensor =
PADDLE_THROW(platform::errors::Unavailable( scope->FindVar(ffn_matmul0_w->Name())->GetMutable<phi::DenseTensor>();
"fused_multi_transformer not supported weight dtype. " auto* ffn1_w_tensor =
"we now only support fp32 and fp16.")); scope->FindVar(ffn_matmul1_w->Name())->GetMutable<phi::DenseTensor>();
TransposeWeights(out_linear_w_tensor);
TransposeWeights(ffn0_w_tensor);
TransposeWeights(ffn1_w_tensor);
} }
// create fused_multi_transformer // create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); fused_multi_transformer_op_desc.SetType(enable_int8
? "fused_multi_transformer_int8"
: "fused_multi_transformer");
// 1. Input setting // 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
...@@ -2638,7 +3067,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2638,7 +3067,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024 // FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType( cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype())); framework::TransToProtoVarType(qkv_b_tensor->dtype()));
cache_kv_desc.SetPersistable(false); cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
...@@ -2653,7 +3082,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2653,7 +3082,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", fill_const_op_desc.SetAttr("dtype",
static_cast<int>(framework::TransToProtoVarType( static_cast<int>(framework::TransToProtoVarType(
qkv_w_tensor->dtype()))); qkv_b_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -2696,8 +3125,129 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2696,8 +3125,129 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("ring_id", fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id")); c_identity_op->GetAttr("ring_id"));
// Quantization attribute/Input
if (enable_int8) {
// Set input scale
std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0];
auto qkv_in_scale = PADDLE_GET_CONST(
float,
c_identity_op->GetAttr("Input_scale_" + matmul_input_scale_suffix));
std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
auto out_linear_in_scale = PADDLE_GET_CONST(
float,
matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
auto* ffn_c_identity_op = ffn_c_identity->Op();
std::string ffn_input_scale_suffix = ffn_c_identity_op->Input("X")[0];
auto ffn0_in_scale = PADDLE_GET_CONST(
float,
ffn_c_identity_op->GetAttr("Input_scale_" + ffn_input_scale_suffix));
std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
auto ffn1_in_scale = PADDLE_GET_CONST(
float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
      // Calculate output scales and set them
auto qkv_weight_scale =
PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale"));
auto out_weight_scale =
PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale"));
auto ffn0_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale"));
auto ffn1_weight_scale =
PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale"));
auto qkv_out_scales = std::vector<float>(
3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f));
auto out_out_scales = std::vector<float>(
dim_embed,
(out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f));
auto ffn0_out_scales = std::vector<float>(
4 * dim_embed,
(ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f));
auto ffn1_out_scales = std::vector<float>(
dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f));
      // Invert the input scales before storing them as attributes
qkv_in_scale = 1.0f / qkv_in_scale;
out_linear_in_scale = 1.0f / out_linear_in_scale;
ffn0_in_scale = 1.0f / ffn0_in_scale;
ffn1_in_scale = 1.0f / ffn1_in_scale;
fused_multi_transformer_op_desc.SetAttr("qkv_in_scale",
std::vector<float>{qkv_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"out_linear_in_scale", std::vector<float>{out_linear_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn1_in_scale", std::vector<float>{ffn0_in_scale});
fused_multi_transformer_op_desc.SetAttr(
"ffn2_in_scale", std::vector<float>{ffn1_in_scale});
auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale");
auto out_out_scale_var =
scope->Var(matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_var =
scope->Var(ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_var =
scope->Var(ffn_matmul1_w->Name() + "_out_scale");
auto qkv_out_scale_data =
qkv_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({3 * dim_embed}, platform::CPUPlace());
memcpy(qkv_out_scale_data,
qkv_out_scales.data(),
qkv_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"QKVOutScale", {matmul0_w->Name() + "_out_scale"});
auto out_out_scale_data =
out_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(out_out_scale_data,
out_out_scales.data(),
out_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"});
auto ffn0_out_scale_data =
ffn0_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({4 * dim_embed}, platform::CPUPlace());
memcpy(ffn0_out_scale_data,
ffn0_out_scales.data(),
ffn0_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"});
auto ffn1_out_scale_data =
ffn1_out_scale_var->GetMutable<phi::DenseTensor>()
->mutable_data<float>({dim_embed}, platform::CPUPlace());
memcpy(ffn1_out_scale_data,
ffn1_out_scales.data(),
ffn1_out_scales.size() * sizeof(float));
fused_multi_transformer_op_desc.SetInput(
"FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"});
}
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
if (enable_int8) {
auto qkv_out_scale_node =
CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale");
auto out_out_scale_node = CreatePersistableVarNode(
graph, matmul_linear_w->Name() + "_out_scale");
auto ffn0_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale");
auto ffn1_out_scale_node =
CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale");
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer);
IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer);
}
IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
...@@ -2977,12 +3527,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2977,12 +3527,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
c_identity, c_identity,
matmul0,
matmul0_w, matmul0_w,
eltadd0_b, eltadd0_b,
split0_k_out, split0_k_out,
split0_v_out, split0_v_out,
eltadd_qk_b, eltadd_qk_b,
reshape2_0, reshape2_0,
matmul_linear,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
while0, while0,
...@@ -2991,7 +3543,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2991,7 +3543,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_layer_norm_bias, ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_c_identity,
ffn_matmul0,
ffn_matmul0_w, ffn_matmul0_w,
ffn_matmul1,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
......
...@@ -188,6 +188,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -188,6 +188,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass"); PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass");
...@@ -334,6 +335,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -334,6 +335,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.elementwise_add(attention_out, ffn_eltadd1_out); layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("enable_int8", new bool(false));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get( auto pass = PassRegistry::Instance().Get(
...@@ -489,6 +491,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -489,6 +491,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.elementwise_add(attention_out, ffn_eltadd1_out); layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("enable_int8", new bool(false));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get( auto pass = PassRegistry::Instance().Get(
......
...@@ -3175,6 +3175,73 @@ void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() { ...@@ -3175,6 +3175,73 @@ void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() {
any_op2->LinksFrom({weight_dequantize_linear_op_out}); any_op2->LinksFrom({weight_dequantize_linear_op_out});
} }
void patterns::DeleteWeightDequantLinearOpEncoderPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
  // Encoder-only: the dequantized weight must also be consumed by a while op.
auto *while0 =
pattern->NewNode(while0_repr())->assert_is_op("while")->AsOutput();
while0->LinksFrom({weight_dequantize_linear_op_out});
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteWeightDequantLinearOpDecoderPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteQuantDequantLinearOpPattern::operator()() { void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr()) auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
->AsInput() ->AsInput()
......
...@@ -1765,6 +1765,39 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase { ...@@ -1765,6 +1765,39 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2); PATTERN_DECL_NODE(any_op2);
}; };
struct DeleteWeightDequantLinearOpEncoderPattern : public PatternBase {
DeleteWeightDequantLinearOpEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(while0);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteWeightDequantLinearOpDecoderPattern : public PatternBase {
DeleteWeightDequantLinearOpDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantLinearOpPattern : public PatternBase { struct DeleteQuantDequantLinearOpPattern : public PatternBase {
DeleteQuantDequantLinearOpPattern(PDPattern* pattern, DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope) const std::string& name_scope)
......
...@@ -46,7 +46,10 @@ static const std::vector<std::string> support_subgraph_passes = { ...@@ -46,7 +46,10 @@ static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_decoder_fuse_qkv_pass", "fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass"}; "fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_encoder_pass",
"delete_weight_dequant_linear_op_decoder_pass"};
Graph *Pass::Apply(Graph *graph) const { Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph"; VLOG(10) << "start to apply pass " << Type() << " to graph";
......
...@@ -165,6 +165,9 @@ const std::vector<std::string> kLiteSubgraphPasses({ ...@@ -165,6 +165,9 @@ const std::vector<std::string> kLiteSubgraphPasses({
// running errors. After fusion operator supports low precision, delete this. // running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{ const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass", "simplify_with_basic_ops_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_encoder_pass",
"delete_weight_dequant_linear_op_decoder_pass",
"map_depthwise_conv_to_conv_pass", "map_depthwise_conv_to_conv_pass",
"conv_bn_fuse_pass", "conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass",
...@@ -203,9 +206,12 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{ ...@@ -203,9 +206,12 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
// "identity_scale_op_clean_pass", // // "identity_scale_op_clean_pass", //
"is_test_pass", // "is_test_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"map_depthwise_conv_to_conv_pass", "delete_quant_dequant_linear_op_pass", //
"delete_weight_dequant_linear_op_encoder_pass", //
"delete_weight_dequant_linear_op_decoder_pass", //
"map_depthwise_conv_to_conv_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", //
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
...@@ -27,6 +28,7 @@ namespace paddle { ...@@ -27,6 +28,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
using phi::backends::gpu::GpuLaunchConfig;
template <typename T> template <typename T>
class AttnMatmulINT8 { class AttnMatmulINT8 {
...@@ -36,6 +38,9 @@ class AttnMatmulINT8 { ...@@ -36,6 +38,9 @@ class AttnMatmulINT8 {
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n); auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper); helpers_.emplace_back(helper);
gpu_config_ = std::make_unique<GpuLaunchConfig>(
phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, m * n, DequantKernelVecSize));
} }
~AttnMatmulINT8() {} ~AttnMatmulINT8() {}
...@@ -50,7 +55,6 @@ class AttnMatmulINT8 { ...@@ -50,7 +55,6 @@ class AttnMatmulINT8 {
phi::DenseTensor* bias_out, phi::DenseTensor* bias_out,
const float quant_in_scale, const float quant_in_scale,
const phi::DenseTensor* dequant_out_scale, const phi::DenseTensor* dequant_out_scale,
const int quant_out_scale_offset,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) { const float quant_min_bound = -127.0) {
...@@ -74,9 +78,9 @@ class AttnMatmulINT8 { ...@@ -74,9 +78,9 @@ class AttnMatmulINT8 {
m_, m_,
n_, n_,
dev_ctx_.stream(), dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale, quant_in_scale,
dequant_out_scale->data<float>(), dequant_out_scale->data<float>());
quant_out_scale_offset);
if (compute_bias_) { if (compute_bias_) {
// bias_out = output + bias // bias_out = output + bias
...@@ -99,11 +103,13 @@ class AttnMatmulINT8 { ...@@ -99,11 +103,13 @@ class AttnMatmulINT8 {
phi::DenseTensor* input, phi::DenseTensor* input,
const phi::DenseTensor* bias, const phi::DenseTensor* bias,
phi::DenseTensor* output, phi::DenseTensor* output,
phi::DenseTensor* bias_out) { phi::DenseTensor* bias_out,
void* workspace = nullptr) {
helpers_[0]->GEMM(input->data<int8_t>(), helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(), weight->data<int8_t>(),
output->data<int32_t>(), output->data<int32_t>(),
dev_ctx_.stream()); dev_ctx_.stream(),
workspace);
} }
// This function is used to execute GEMM, with input and output's types are // This function is used to execute GEMM, with input and output's types are
...@@ -115,8 +121,7 @@ class AttnMatmulINT8 { ...@@ -115,8 +121,7 @@ class AttnMatmulINT8 {
phi::DenseTensor* output, phi::DenseTensor* output,
phi::DenseTensor* output_tmp, phi::DenseTensor* output_tmp,
phi::DenseTensor* bias_out, phi::DenseTensor* bias_out,
const phi::DenseTensor* dequant_out_scale, const phi::DenseTensor* dequant_out_scale) {
const int quant_out_scale_offset) {
helpers_[0]->GEMM(input->data<int8_t>(), helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(), weight->data<int8_t>(),
output_tmp->data<int32_t>(), output_tmp->data<int32_t>(),
...@@ -127,9 +132,9 @@ class AttnMatmulINT8 { ...@@ -127,9 +132,9 @@ class AttnMatmulINT8 {
m_, m_,
n_, n_,
dev_ctx_.stream(), dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale, quant_in_scale,
dequant_out_scale->data<float>(), dequant_out_scale->data<float>());
quant_out_scale_offset);
if (compute_bias_) { if (compute_bias_) {
// bias_out = output + bias // bias_out = output + bias
...@@ -183,6 +188,7 @@ class AttnMatmulINT8 { ...@@ -183,6 +188,7 @@ class AttnMatmulINT8 {
int compute_bias_; int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_; std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
std::unique_ptr<GpuLaunchConfig> gpu_config_;
}; };
} // namespace operators } // namespace operators
......
...@@ -24,6 +24,20 @@ namespace dyl = paddle::platform::dynload; ...@@ -24,6 +24,20 @@ namespace dyl = paddle::platform::dynload;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
struct CublasLtAlgoParam {
int algoId;
int swizzle;
int customOption;
int tile;
int splitK_val;
int reductionScheme;
int stages;
size_t workspace_size;
};
const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
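// The cache above ships empty, so every (m, k, n) lookup in the constructor
// below falls back to the default algo parameters. A build with offline-tuned
// GEMM shapes could pre-populate it as sketched here; every number is
// hypothetical and only shows the shape of an entry, which is why the sketch
// is left commented out.
//
// const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{
//     {std::make_tuple(/*m=*/1024, /*k=*/768, /*n=*/2304),
//      {/*algoId=*/21, /*swizzle=*/0, /*customOption=*/0, /*tile=*/20,
//       /*splitK_val=*/0, /*reductionScheme=*/0, /*stages=*/17,
//       /*workspace_size=*/0}},
// };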
class CublasLtHelper { class CublasLtHelper {
public: public:
CublasLtHelper(int m, int k, int n) CublasLtHelper(int m, int k, int n)
...@@ -99,38 +113,34 @@ class CublasLtHelper { ...@@ -99,38 +113,34 @@ class CublasLtHelper {
"cublasLtMatrixLayoutCreate execution error" "cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more " "refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information")); "information"));
}
~CublasLtHelper() {
if (handle_) dyl::cublasLtDestroy(handle_);
if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_);
if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_);
if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_);
if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_);
}
void GEMM(int8_t* A_dev, #if CUDA_VERSION >= 11020
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream) {
cublasStatus_t status;
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
cublasLtMatmulAlgo_t algo;
int algoId = 21; int algoId = 21;
int swizzle = 0; int swizzle = 0;
int customOption = 0; int customOption = 0;
int tile = 15; int tile = 15;
int splitK_val = 0; int splitK_val = 0;
int reductionScheme = 0; int reductionScheme = 0;
#if CUDA_VERSION >= 11000
int stages = 23; int stages = 23;
#endif workspace_size_ = 0;
if (m >= 128) {
#if CUBLAS_VER_MAJOR < 11 tile = 20;
cudaDataType_t cudaComputeType = CUDA_R_32I; stages = 17;
#else }
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif std::tuple<int, int, int> key(m_, k_, n_);
if (AlgoParamCache.count(key) != 0) {
auto value = AlgoParamCache.at(key);
algoId = value.algoId;
swizzle = value.swizzle;
customOption = value.customOption;
tile = value.tile;
splitK_val = value.splitK_val;
reductionScheme = value.reductionScheme;
stages = value.stages;
workspace_size_ = value.workspace_size;
}
dyl::cublasLtMatmulAlgoInit(handle_, dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType, cudaComputeType,
...@@ -140,30 +150,43 @@ class CublasLtHelper { ...@@ -140,30 +150,43 @@ class CublasLtHelper {
CUDA_R_32I, CUDA_R_32I,
CUDA_R_32I, CUDA_R_32I,
algoId, algoId,
&algo); &algo_);
dyl::cublasLtMatmulAlgoConfigSetAttribute( dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, &algo_,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption), &(customOption),
sizeof(customOption)); sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute( dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo, dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val), &(splitK_val),
sizeof(splitK_val)); sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute( dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); &algo_,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING,
&(swizzle),
sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute( dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, &algo_,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme), &(reductionScheme),
sizeof(int)); sizeof(int));
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute( dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif #endif
#endif #endif
}
~CublasLtHelper() {}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
void* workspace = nullptr) {
cublasStatus_t status;
status = dyl::cublasLtMatmul(handle_, status = dyl::cublasLtMatmul(handle_,
matmul_desc_, matmul_desc_,
&alpha_, &alpha_,
...@@ -176,13 +199,15 @@ class CublasLtHelper { ...@@ -176,13 +199,15 @@ class CublasLtHelper {
C_desc_, C_desc_,
C_dev, C_dev,
C_desc_, C_desc_,
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 #if CUDA_VERSION >= 11020
&algo, &algo_,
workspace,
workspace_size_,
#else #else
nullptr, nullptr,
#endif
nullptr, nullptr,
0, 0,
#endif
stream); stream);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
status, status,
...@@ -199,12 +224,17 @@ class CublasLtHelper { ...@@ -199,12 +224,17 @@ class CublasLtHelper {
cublasLtMatrixLayout_t A_desc_; cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_; cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_; cublasLtMatrixLayout_t C_desc_;
cublasLtMatmulAlgo_t algo_;
int32_t alpha_; int32_t alpha_;
int32_t beta_; int32_t beta_;
int m_; int m_;
int k_; int k_;
int n_; int n_;
size_t workspace_size_;
}; };
} // namespace operators } // namespace operators
......
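Editor's note: in the hunk above, algorithm selection moves from `GEMM()` into the `CublasLtHelper` constructor. Defaults (`algoId = 21`, `tile = 15`, `stages = 23`, larger tile and fewer stages when `m >= 128`) can be overridden by an `(m, k, n)`-keyed `AlgoParamCache`, and the chosen `workspace_size` is carried alongside the algo. A minimal sketch of that lookup, assuming the struct layout shown in the diff; the cache contents themselves are tuning data and are not reproduced here.

```cpp
// Minimal sketch of the (m, k, n) -> algo-parameter lookup.
#include <cstddef>
#include <map>
#include <tuple>

struct CublasLtAlgoParam {
  int algoId, swizzle, customOption, tile, splitK_val, reductionScheme, stages;
  size_t workspace_size;
};

CublasLtAlgoParam SelectAlgoParam(
    const std::map<std::tuple<int, int, int>, CublasLtAlgoParam>& cache,
    int m, int k, int n) {
  // Defaults mirror the constructor above; large-m shapes prefer a bigger
  // tile and fewer stages.
  CublasLtAlgoParam p{21, 0, 0, 15, 0, 0, 23, 0};
  if (m >= 128) { p.tile = 20; p.stages = 17; }
  auto it = cache.find(std::make_tuple(m, k, n));
  if (it != cache.end()) p = it->second;  // a tuned entry overrides defaults
  return p;
}
```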
...@@ -86,7 +86,6 @@ __global__ void FusedDropoutActBias( ...@@ -86,7 +86,6 @@ __global__ void FusedDropoutActBias(
MaskType *mask, MaskType *mask,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -127,7 +126,6 @@ __global__ void FusedDropoutActBias( ...@@ -127,7 +126,6 @@ __global__ void FusedDropoutActBias(
act, act,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -146,7 +144,13 @@ __global__ void FusedActBias(Functor act, ...@@ -146,7 +144,13 @@ __global__ void FusedActBias(Functor act,
const uint64_t cols, const uint64_t cols,
const InType *__restrict__ src, const InType *__restrict__ src,
const T *__restrict__ bias, const T *__restrict__ bias,
OutType *dst) { OutType *dst,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>; using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>; using LoadInType = phi::AlignedVector<InType, VecSize>;
...@@ -156,23 +160,42 @@ __global__ void FusedActBias(Functor act, ...@@ -156,23 +160,42 @@ __global__ void FusedActBias(Functor act,
LoadInType src_vec; LoadInType src_vec;
LoadT bias_vec; LoadT bias_vec;
StoreOutType out_vec; StoreOutType out_vec;
LoadFloat dequant_out_scale_vec;
for (int32_t idx = global_thread_idx * VecSize, for (int32_t idx = global_thread_idx * VecSize,
step = blockDim.x * gridDim.x * VecSize; step = blockDim.x * gridDim.x * VecSize;
idx < elem_cnt; idx < elem_cnt;
idx += step) { idx += step) {
const int32_t col_idx = idx % cols; const int32_t col_idx = idx % cols;
phi::Load<InType, VecSize>(&src[idx], &src_vec); phi::Load<InType, VecSize>(&src[idx], &src_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_idx],
&dequant_out_scale_vec);
if (bias) { if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec); phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
} }
#pragma unroll #pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
if (bias) { T tmp;
out_vec[unroll_idx] = static_cast<OutType>( if (std::is_same<InType, int32_t>::value) {
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx])); tmp = static_cast<T>(static_cast<float>(src_vec[unroll_idx]) *
dequant_out_scale_vec[unroll_idx]);
if (bias) {
tmp = static_cast<T>(act(tmp + bias_vec[unroll_idx]));
} else {
tmp = static_cast<T>(act(tmp));
}
out_vec[unroll_idx] = quant_helper(tmp,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else { } else {
out_vec[unroll_idx] = if (bias) {
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx]))); out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
}
} }
} }
phi::Store<OutType, VecSize>(out_vec, &dst[idx]); phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
...@@ -202,7 +225,6 @@ void LaunchDropoutActBias(Functor act_functor, ...@@ -202,7 +225,6 @@ void LaunchDropoutActBias(Functor act_functor,
const phi::GPUContext &ctx, const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -218,7 +240,7 @@ void LaunchDropoutActBias(Functor act_functor, ...@@ -218,7 +240,7 @@ void LaunchDropoutActBias(Functor act_functor,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1; const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) { if (cols % VecSize == 0) {
if (is_test && (dequant_out_scale_data == nullptr)) { if (is_test) {
const int32_t elem_cnt = rows * cols; const int32_t elem_cnt = rows * cols;
const int32_t pack_num = elem_cnt / VecSize; const int32_t pack_num = elem_cnt / VecSize;
const int32_t tmp_cols = cols / VecSize; const int32_t tmp_cols = cols / VecSize;
...@@ -227,8 +249,15 @@ void LaunchDropoutActBias(Functor act_functor, ...@@ -227,8 +249,15 @@ void LaunchDropoutActBias(Functor act_functor,
const int grid_size = std::max(static_cast<int32_t>(1), const int grid_size = std::max(static_cast<int32_t>(1),
(pack_num + block_size - 1) / block_size); (pack_num + block_size - 1) / block_size);
FusedActBias<T, VecSize, Functor, InType, OutType> FusedActBias<T, VecSize, Functor, InType, OutType>
<<<grid_size, block_size, 0, ctx.stream()>>>( <<<grid_size, block_size, 0, ctx.stream()>>>(act_functor,
act_functor, elem_cnt, cols, src, bias, dst); elem_cnt,
cols,
src,
bias,
dst,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
} else { } else {
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType> FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
...@@ -246,7 +275,6 @@ void LaunchDropoutActBias(Functor act_functor, ...@@ -246,7 +275,6 @@ void LaunchDropoutActBias(Functor act_functor,
mask_data, mask_data,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} }
} else { } else {
...@@ -266,7 +294,6 @@ void LaunchDropoutActBias(Functor act_functor, ...@@ -266,7 +294,6 @@ void LaunchDropoutActBias(Functor act_functor,
mask_data, mask_data,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} }
} }
......
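Editor's note: `FusedActBias` now receives dequant/quant parameters, so for `int32_t` input each element is dequantized with the per-channel output scale, biased, activated, and re-quantized to int8 for the next GEMM. Below is a host-side scalar mirror of that path, a sketch only: it always adds the bias and uses round-to-nearest instead of the kernel's configurable `quant_round_type`.

```cpp
// Host-side mirror of the per-element int32 path in FusedActBias:
// dequantize -> add bias -> activation -> quantize to int8.
#include <algorithm>
#include <cmath>
#include <cstdint>

template <typename Act>
int8_t DequantActQuant(int32_t acc, float dequant_out_scale, float bias,
                       float quant_next_in_scale, Act act,
                       float max_bound = 127.f, float min_bound = -127.f) {
  float x = static_cast<float>(acc) * dequant_out_scale + bias;  // dequant + bias
  float y = act(x);                                              // e.g. GeLU/ReLU
  float q = std::round(max_bound * quant_next_in_scale * y);     // quantize
  return static_cast<int8_t>(std::min(max_bound, std::max(min_bound, q)));
}
```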
...@@ -154,7 +154,6 @@ class FusedDropoutHelper { ...@@ -154,7 +154,6 @@ class FusedDropoutHelper {
MaskType* mask, MaskType* mask,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr, const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) { const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx); auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType, InType, OutType>( LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
...@@ -173,7 +172,6 @@ class FusedDropoutHelper { ...@@ -173,7 +172,6 @@ class FusedDropoutHelper {
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} }
...@@ -212,7 +210,6 @@ class FusedDropoutHelper { ...@@ -212,7 +210,6 @@ class FusedDropoutHelper {
MaskType* mask, MaskType* mask,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr, const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -237,7 +234,6 @@ class FusedDropoutHelper { ...@@ -237,7 +234,6 @@ class FusedDropoutHelper {
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -260,7 +256,6 @@ class FusedDropoutHelper { ...@@ -260,7 +256,6 @@ class FusedDropoutHelper {
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -287,7 +282,6 @@ class FusedDropoutHelper { ...@@ -287,7 +282,6 @@ class FusedDropoutHelper {
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -454,7 +448,6 @@ class FusedDropoutLayerNormHelper ...@@ -454,7 +448,6 @@ class FusedDropoutLayerNormHelper
LayerNormParamType<T>* variance, LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr, const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -494,7 +487,6 @@ class FusedDropoutLayerNormHelper ...@@ -494,7 +487,6 @@ class FusedDropoutLayerNormHelper
ctx, ctx,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale, quant_next_in_scale,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
......
...@@ -442,7 +442,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -442,7 +442,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
OutType *__restrict__ y_ptr, OutType *__restrict__ y_ptr,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *__restrict__ quant_out_scale_ptr = nullptr, const float *__restrict__ quant_out_scale_ptr = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -504,9 +503,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -504,9 +503,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
phi::Load<InType, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, phi::Load<InType, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
&x_input[it]); &x_input[it]);
if (quant_out_scale_ptr != nullptr) { if (quant_out_scale_ptr != nullptr) {
phi::Load<float, VecSize>( phi::Load<float, VecSize>(quant_out_scale_ptr + col * VecSize,
quant_out_scale_ptr + quant_out_scale_offset + col * VecSize, &dequant_out_scale[it]);
&dequant_out_scale[it]);
} }
col += THREADS_PER_ROW; col += THREADS_PER_ROW;
} }
...@@ -543,7 +541,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -543,7 +541,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
// dropout(x) + residual // dropout(x) + residual
if (std::is_same<InType, int32_t>::value) { if (std::is_same<InType, int32_t>::value) {
T tmp = (static_cast<T>(static_cast<float>(x_input[it][jt]) * T tmp = (static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) + dequant_out_scale[it][jt]) +
bias[it][jt]) * bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
...@@ -567,7 +564,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -567,7 +564,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
if (std::is_same<InType, int32_t>::value) { if (std::is_same<InType, int32_t>::value) {
// for int32 input, we need to dequantize. // for int32 input, we need to dequantize.
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) * T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) * dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor + static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt]; residual[it][jt];
...@@ -752,7 +748,6 @@ void LaunchLayernormResidualDropoutBias( ...@@ -752,7 +748,6 @@ void LaunchLayernormResidualDropoutBias(
const phi::GPUContext &ctx, const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -844,7 +839,6 @@ void LaunchLayernormResidualDropoutBias( ...@@ -844,7 +839,6 @@ void LaunchLayernormResidualDropoutBias(
layernorm_dst, \ layernorm_dst, \
quant_last_in_scale, \ quant_last_in_scale, \
dequant_out_scale_data, \ dequant_out_scale_data, \
quant_out_scale_offset, \
quant_next_in_scale, \ quant_next_in_scale, \
quant_round_type, \ quant_round_type, \
quant_max_bound, \ quant_max_bound, \
......
...@@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { ...@@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel {
CHECK_INPUTS(FFN1Weight); CHECK_INPUTS(FFN1Weight);
CHECK_INPUTS(FFN2Weight); CHECK_INPUTS(FFN2Weight);
// scale
CHECK_INPUTS(QKVOutScale);
CHECK_INPUTS(OutLinearOutScale);
CHECK_INPUTS(FFN1OutScale);
CHECK_INPUTS(FFN2OutScale);
CHECK_OUTPUT(Out); CHECK_OUTPUT(Out);
// x: qkv's input [batch_size, seq_len, dim_embed] // x: qkv's input [batch_size, seq_len, dim_embed]
...@@ -232,20 +238,24 @@ class FusedMultiTransformerINT8OpMaker ...@@ -232,20 +238,24 @@ class FusedMultiTransformerINT8OpMaker
"In order to keep consistent with the PTQ/QAT calculation logic," "In order to keep consistent with the PTQ/QAT calculation logic,"
"QKVOutScale should be max_bound * max_bound / max_range." "QKVOutScale should be max_bound * max_bound / max_range."
"Here max_range is per-channel weight scale." "Here max_range is per-channel weight scale."
"The shape of QKVOutScale is [num_layers, num_channels]") "The shape of QKVOutScale is [num_channels]")
.AsDispensable(); .AsDispensable()
.AsDuplicable();
AddInput("OutLinearOutScale", AddInput("OutLinearOutScale",
"OutLinearOutScale is used to dequantize out_linear output tensor." "OutLinearOutScale is used to dequantize out_linear output tensor."
"The definition and shape is the same as QKVOutScale") "The definition and shape is the same as QKVOutScale")
.AsDispensable(); .AsDispensable()
.AsDuplicable();
AddInput("FFN1OutScale", AddInput("FFN1OutScale",
"FFN1OutScale is used to dequantize ffn1 output tensor." "FFN1OutScale is used to dequantize ffn1 output tensor."
"The definition and shape is the same as QKVOutScale") "The definition and shape is the same as QKVOutScale")
.AsDispensable(); .AsDispensable()
.AsDuplicable();
AddInput("FFN2OutScale", AddInput("FFN2OutScale",
"FFN2OutScale is used to dequantize ffn2 output tensor." "FFN2OutScale is used to dequantize ffn2 output tensor."
"The definition and shape is the same as QKVOutScale") "The definition and shape is the same as QKVOutScale")
.AsDispensable(); .AsDispensable()
.AsDuplicable();
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
.AsDispensable() .AsDispensable()
......
...@@ -48,16 +48,11 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -48,16 +48,11 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// dequant output scales, tensor, size = [num_layers, n], n is gemm output // dequant output scales, tensor, size = [num_layers, n], n is gemm output
// size // size
auto *qkv_out_scale = ctx.Input<phi::DenseTensor>("QKVOutScale"); auto qkv_out_scales = ctx.MultiInput<phi::DenseTensor>("QKVOutScale");
auto *out_linear_out_scale = auto out_linear_out_scales =
ctx.Input<phi::DenseTensor>("OutLinearOutScale"); ctx.MultiInput<phi::DenseTensor>("OutLinearOutScale");
auto *ffn1_out_scale = ctx.Input<phi::DenseTensor>("FFN1OutScale"); auto ffn1_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN1OutScale");
auto *ffn2_out_scale = ctx.Input<phi::DenseTensor>("FFN2OutScale"); auto ffn2_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN2OutScale");
int qkv_out_scale_n = qkv_out_scale->dims()[1];
int out_linear_out_scale_n = out_linear_out_scale->dims()[1];
int ffn1_out_scale_n = ffn1_out_scale->dims()[1];
int ffn2_out_scale_n = ffn2_out_scale->dims()[1];
// 1. layer norm // 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm"); const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
...@@ -132,6 +127,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -132,6 +127,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
auto *transpose_out_2_data = auto *transpose_out_2_data =
dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));
qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T)); auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
...@@ -232,19 +228,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -232,19 +228,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
// []. init workspace for cublasLt transform // []. init workspace for cublasLt transform
Tensor input_workspace, output_workspace; Tensor input_workspace, output_workspace, cublaslt_workspace;
// for input and output transform data is CUBLASLT_ORDER_COL32 format, // for input and output transform data is CUBLASLT_ORDER_COL32 format,
int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn),
n_max = std::max({output_size, dim_embed, dim_ffn}); n_max = std::max({output_size, dim_embed, dim_ffn});
input_workspace.Resize( input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}});
{{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}});
dev_ctx.Alloc<int8_t>(&input_workspace, dev_ctx.Alloc<int8_t>(&input_workspace,
input_workspace.numel() * sizeof(int8_t)); input_workspace.numel() * sizeof(int8_t));
output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}});
output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}});
dev_ctx.Alloc<int32_t>(&output_workspace, dev_ctx.Alloc<int32_t>(&output_workspace,
output_workspace.numel() * sizeof(int32_t)); output_workspace.numel() * sizeof(int32_t));
cublaslt_workspace.Resize({{3000000}});
dev_ctx.Alloc<int8_t>(&cublaslt_workspace,
cublaslt_workspace.numel() * sizeof(int8_t));
// calc // calc
auto *out = ctx.Output<phi::DenseTensor>("Out"); auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T)); auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
...@@ -305,8 +305,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -305,8 +305,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace, &output_workspace,
&qkv_out, &qkv_out,
qkv_in_scale[i], qkv_in_scale[i],
qkv_out_scale, qkv_out_scales[i],
i * qkv_out_scale_n,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
...@@ -319,8 +318,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -319,8 +318,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace, &output_workspace,
&qkv_out, &qkv_out,
qkv_in_scale[i], qkv_in_scale[i],
qkv_out_scale, qkv_out_scales[i],
i * qkv_out_scale_n,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
...@@ -332,8 +330,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -332,8 +330,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&qkv_out, &qkv_out,
&output_workspace, &output_workspace,
&qkv_out, &qkv_out,
qkv_out_scale, qkv_out_scales[i]);
i * qkv_out_scale_n);
} }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2"; VLOG(0) << "step2";
...@@ -441,8 +438,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -441,8 +438,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace, &output_workspace,
nullptr, nullptr,
out_linear_in_scale[i], out_linear_in_scale[i],
out_linear_out_scale, out_linear_out_scales[i],
i * out_linear_out_scale_n,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
...@@ -473,8 +469,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -473,8 +469,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
ln_mean_data, ln_mean_data,
ln_var_data, ln_var_data,
out_linear_in_scale[i], out_linear_in_scale[i],
out_linear_out_scale->data<float>(), out_linear_out_scales[i]->data<float>(),
i * out_linear_out_scale_n,
ffn1_in_scale[i], ffn1_in_scale[i],
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -504,11 +499,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -504,11 +499,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// step6. ffn matmul1 // step6. ffn matmul1
if (pre_layer_norm) { if (pre_layer_norm) {
ffn1_linear_compute.ComputeForwardINT8ToINT8(ffn1_weights[i], ffn1_linear_compute.ComputeForwardINT8ToINT8(
&input_workspace, ffn1_weights[i],
nullptr, &input_workspace,
&output_workspace, nullptr,
nullptr); &output_workspace,
nullptr,
cublaslt_workspace.data<int8_t>());
} else { } else {
ffn1_linear_compute.ComputeForward(ffn1_weights[i], ffn1_linear_compute.ComputeForward(ffn1_weights[i],
buf1, buf1,
...@@ -518,8 +515,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -518,8 +515,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace, &output_workspace,
nullptr, nullptr,
ffn1_in_scale[i], ffn1_in_scale[i],
ffn1_out_scale, ffn1_out_scales[i],
i * ffn1_out_scale_n,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
...@@ -539,8 +535,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -539,8 +535,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
input_workspace.data<int8_t>(), input_workspace.data<int8_t>(),
ffn1_dropout_mask_data, ffn1_dropout_mask_data,
ffn1_in_scale[i], ffn1_in_scale[i],
ffn1_out_scale->data<float>(), ffn1_out_scales[i]->data<float>(),
i * ffn1_out_scale_n,
ffn2_in_scale[i], ffn2_in_scale[i],
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -560,11 +555,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -560,11 +555,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// step8. ffn matmul2 // step8. ffn matmul2
if (pre_layer_norm) { if (pre_layer_norm) {
ffn2_linear_compute.ComputeForwardINT8ToINT8(ffn2_weights[i], ffn2_linear_compute.ComputeForwardINT8ToINT8(
&input_workspace, ffn2_weights[i],
nullptr, &input_workspace,
&output_workspace, nullptr,
nullptr); &output_workspace,
nullptr,
cublaslt_workspace.data<int8_t>());
} else { } else {
ffn2_linear_compute.ComputeForward(ffn2_weights[i], ffn2_linear_compute.ComputeForward(ffn2_weights[i],
&ffn1_dropout_out, &ffn1_dropout_out,
...@@ -574,8 +571,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -574,8 +571,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace, &output_workspace,
nullptr, nullptr,
ffn2_in_scale[i], ffn2_in_scale[i],
ffn2_out_scale, ffn2_out_scales[i],
i * ffn2_out_scale_n,
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
quant_min_bound); quant_min_bound);
...@@ -616,8 +612,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -616,8 +612,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
ln_mean_data, ln_mean_data,
ln_var_data, ln_var_data,
ffn2_in_scale[i], ffn2_in_scale[i],
ffn2_out_scale->data<float>(), ffn2_out_scales[i]->data<float>(),
i * ffn2_out_scale_n,
qkv_in_scale[i + 1], qkv_in_scale[i + 1],
quant_round_type, quant_round_type,
quant_max_bound, quant_max_bound,
...@@ -631,8 +626,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> { ...@@ -631,8 +626,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
buf1->data<T>(), buf1->data<T>(),
dropout_mask_out_data, dropout_mask_out_data,
ffn2_in_scale[i], ffn2_in_scale[i],
ffn2_out_scale->data<float>(), ffn2_out_scales[i]->data<float>(),
i * ffn2_out_scale_n,
1.0); 1.0);
} }
} else { } else {
......
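Editor's note: the kernel above switches the dequant scales to per-layer `MultiInput` tensors and flattens the cuBLASLt workspaces: an int8 input buffer of `m_max * k_max` elements, an int32 output buffer of `n_max * m_max` elements (both rounded up to a multiple of 32), plus a fixed scratch buffer passed to `cublasLtMatmul`. The sketch below only restates that sizing; the 3,000,000-byte figure is the constant from the diff, not a derived bound, and `WorkspaceSizes`/`ComputeWorkspaceSizes` are illustrative names.

```cpp
// Sketch of the flat workspace sizing used in the fused INT8 transformer kernel.
#include <algorithm>
#include <cstddef>
#include <cstdint>

struct WorkspaceSizes {
  int64_t input_elems;     // int8_t elements
  int64_t output_elems;    // int32_t elements
  int64_t cublaslt_bytes;  // raw bytes handed to cublasLtMatmul
};

WorkspaceSizes ComputeWorkspaceSizes(int bsz_seq, int dim_embed, int dim_ffn,
                                     int output_size) {
  auto round_up32 = [](int64_t v) { return (v + 31) / 32 * 32; };
  const int64_t m_max = bsz_seq;
  const int64_t k_max = std::max(dim_embed, dim_ffn);
  const int64_t n_max = std::max({static_cast<int64_t>(output_size),
                                  static_cast<int64_t>(dim_embed),
                                  static_cast<int64_t>(dim_ffn)});
  return {round_up32(m_max * k_max), round_up32(n_max * m_max), 3000000};
}
```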
...@@ -49,7 +49,6 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( ...@@ -49,7 +49,6 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
Functor act_func, Functor act_func,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0, const float quant_next_in_scale = 1.0,
const int quant_round_type = 1, const int quant_round_type = 1,
const float quant_max_bound = 127.0, const float quant_max_bound = 127.0,
...@@ -74,9 +73,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( ...@@ -74,9 +73,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
} }
// vectorize load data from global // vectorize load data from global
phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec); phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<float, VecSize>( phi::Load<float, VecSize>(&dequant_out_scale_data[col_id],
&dequant_out_scale_data[quant_out_scale_offset + col_id], &quant_out_scale_vec);
&quant_out_scale_vec);
if (residual) { if (residual) {
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec); phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
} }
...@@ -108,7 +106,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( ...@@ -108,7 +106,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
T tmp; T tmp;
if (std::is_same<InType, int32_t>::value) { if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) * T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_last_in_scale / quant_out_scale_vec[ii]); quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii]; tmp = tmp0 + bias_vec[ii];
} else { } else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii]; tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
...@@ -172,7 +170,6 @@ __global__ void FusedResidualDropoutBias( ...@@ -172,7 +170,6 @@ __global__ void FusedResidualDropoutBias(
const bool is_test, const bool is_test,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) { const float quant_next_in_scale = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x; int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y; int row_id = blockIdx.y;
...@@ -208,7 +205,6 @@ __global__ void FusedResidualDropoutBias( ...@@ -208,7 +205,6 @@ __global__ void FusedResidualDropoutBias(
relu, relu,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} }
} }
...@@ -236,7 +232,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -236,7 +232,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const phi::GPUContext &ctx, const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0, const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr, const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) { const float quant_next_in_scale = 1.0) {
// dropout_prob == 1.0f // dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) { if (std::abs(dropout_prob - 1.0f) < 1e-5) {
...@@ -278,7 +273,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -278,7 +273,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, is_test,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} else { } else {
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType> FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
...@@ -297,7 +291,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, ...@@ -297,7 +291,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test, is_test,
quant_last_in_scale, quant_last_in_scale,
dequant_out_scale_data, dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale); quant_next_in_scale);
} }
} }
......
...@@ -18,17 +18,24 @@ limitations under the License. */ ...@@ -18,17 +18,24 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using phi::backends::gpu::GpuLaunchConfig;
constexpr int DequantKernelVecSize = 4;
template <typename T> template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input, __forceinline__ __device__ int8_t quant_helper(const T input,
const float scale, const float scale,
const int round_type, const int round_type,
const float max_bound, const float max_bound,
const float min_bound) { const float min_bound) {
float quant_value = max_bound * inverse(scale) * static_cast<float>(input); float quant_value = max_bound * scale * static_cast<float>(input);
if (round_type == 0) { if (round_type == 0) {
quant_value = static_cast<float>(roundWithTiesToEven(quant_value)); quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
} else { } else {
...@@ -77,7 +84,7 @@ void quantize_kernel_launcher(const T* input, ...@@ -77,7 +84,7 @@ void quantize_kernel_launcher(const T* input,
const float min_bound, const float min_bound,
gpuStream_t stream) { gpuStream_t stream) {
// TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1
dim3 grid((n + 31) / 32, (m + 31) / 32); dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
dim3 block(32, 32); dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input, quantize_kernel<<<grid, block, 0, stream>>>(input,
...@@ -90,46 +97,48 @@ void quantize_kernel_launcher(const T* input, ...@@ -90,46 +97,48 @@ void quantize_kernel_launcher(const T* input,
min_bound); min_bound);
} }
// dequantize using weight scales and input scales template <typename T, int VecSize>
template <typename T>
__global__ void dequantize_kernel(T* output, __global__ void dequantize_kernel(T* output,
const int32_t* input, const int32_t* input,
const int m, // hidden const int m, // batch size
const int n, // batch size const int n, // hidden
const float quant_in_scale, const float quant_in_scale,
const float* dequant_out_scale_data, const float* dequant_out_scale_data) {
const int quant_out_scale_offset) { int numel = m * n;
int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden int stride = blockDim.x * gridDim.x * VecSize;
int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int col_id = idx % n;
bool check = ((m_id < m) && (n_id < n));
if (check) { phi::AlignedVector<int32_t, VecSize> in_vec;
float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id]; phi::AlignedVector<float, VecSize> out_scale_vec;
output[n_id * m + m_id] = phi::AlignedVector<T, VecSize> out_vec;
static_cast<T>(static_cast<float>(input[n_id * m + m_id]) *
quant_in_scale / out_scale); for (; idx < numel; idx += stride) {
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] =
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
}
phi::Store<T, VecSize>(out_vec, output + idx);
} }
} }
template <typename T> template <typename T>
void dequantize_kernel_launcher(const int32_t* input, void dequantize_kernel_launcher(const int32_t* input,
T* output, T* output,
const int batch_size, // m const int m, // m
const int hidden_units, // n const int n, // n
gpuStream_t stream, gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale, const float quant_in_scale,
const float* dequant_out_scale_data, const float* dequant_out_scale_data) {
const int quant_out_scale_offset) { dequantize_kernel<T, DequantKernelVecSize>
dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32); <<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
dim3 block(32, 32); output, input, m, n, quant_in_scale, dequant_out_scale_data);
dequantize_kernel<<<grid, block, 0, stream>>>(output,
input,
hidden_units,
batch_size,
quant_in_scale,
dequant_out_scale_data,
quant_out_scale_offset);
} }
} // namespace operators } // namespace operators
......
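Editor's note: with the new convention the input scale is already folded into `dequant_out_scale_data`, so `quant_helper` multiplies by the (reciprocal) scale and the vectorized `dequantize_kernel` only multiplies the int32 GEMM output by a per-output-channel factor. A scalar reference of the dequantize step, assuming a row-major `[m, n]` accumulator as in the launcher above; this is a readability aid, not the kernel.

```cpp
// Scalar reference for the vectorized dequantize_kernel: per-column scaling
// of an int32 accumulator of shape [m, n], row-major.
#include <cstdint>

template <typename T>
void DequantizeReference(const int32_t* input, T* output, int m, int n,
                         const float* dequant_out_scale) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      const int idx = row * n + col;
      output[idx] = static_cast<T>(static_cast<float>(input[idx]) *
                                   dequant_out_scale[col]);
    }
  }
}
```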
...@@ -307,7 +307,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -307,7 +307,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
self.attn_mask = None self.attn_mask = None
def fake_quant(self, input, scale): def fake_quant(self, input, scale):
quant_value = 127.0 * (1.0 / scale) * paddle.cast(input, 'float32') quant_value = 127.0 * scale * paddle.cast(input, 'float32')
quant_value = paddle.round(quant_value) quant_value = paddle.round(quant_value)
# No need to clip here because scale is the max value # No need to clip here because scale is the max value
...@@ -333,11 +333,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -333,11 +333,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
if self.pre_layer_norm: if self.pre_layer_norm:
ln1_out = self.norm(tensor_query) ln1_out = self.norm(tensor_query)
max_v = paddle.max(paddle.abs(paddle.cast(ln1_out, 'float32')))[0] max_v = paddle.max(paddle.abs(paddle.cast(ln1_out, 'float32')))[0]
# self.qkv_in_scales.append(127.0 / max_v) self.qkv_in_scales.append(1 / max_v)
self.qkv_in_scales.append(max_v) self.qkv_out_scales.append(max_v / (127.0 * 127.0))
self.qkv_out_scales.append(127.0 * 127.0)
# print('qkv_in_scales ', i, self.qkv_in_scales[i])
# print('qkv_out_scales ', i, self.qkv_out_scales[i])
# quant ln1_out # quant ln1_out
ln1_out = self.fake_quant(ln1_out, self.qkv_in_scales[i]) ln1_out = self.fake_quant(ln1_out, self.qkv_in_scales[i])
...@@ -345,9 +342,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -345,9 +342,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
q = paddle.nn.functional.linear(ln1_out, self.q_weight_tensor) q = paddle.nn.functional.linear(ln1_out, self.q_weight_tensor)
# de quant # de quant
q = paddle.cast( q = paddle.cast(
paddle.cast(q, 'float32') paddle.cast(q, 'float32') * self.qkv_out_scales[i],
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
self.x_type, self.x_type,
) )
...@@ -357,17 +352,13 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -357,17 +352,13 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
k = paddle.nn.functional.linear(ln1_out, self.k_weight_tensor) k = paddle.nn.functional.linear(ln1_out, self.k_weight_tensor)
k = paddle.cast( k = paddle.cast(
paddle.cast(k, 'float32') paddle.cast(k, 'float32') * self.qkv_out_scales[i],
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
self.x_type, self.x_type,
) )
k = k + self.k_proj_bias_tensor k = k + self.k_proj_bias_tensor
v = paddle.nn.functional.linear(ln1_out, self.v_weight_tensor) v = paddle.nn.functional.linear(ln1_out, self.v_weight_tensor)
v = paddle.cast( v = paddle.cast(
paddle.cast(v, 'float32') paddle.cast(v, 'float32') * self.qkv_out_scales[i],
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
self.x_type, self.x_type,
) )
v = v + self.v_proj_bias_tensor v = v + self.v_proj_bias_tensor
...@@ -442,10 +433,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -442,10 +433,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
max_v = paddle.max( max_v = paddle.max(
paddle.abs(paddle.cast(out_linear_in, 'float32')) paddle.abs(paddle.cast(out_linear_in, 'float32'))
)[0] )[0]
# self.out_linear_in_scales.append(127.0 / max_v)
self.out_linear_in_scales.append(max_v) self.out_linear_in_scales.append(1 / max_v)
self.out_linear_out_scales.append((127.0 * 127.0)) self.out_linear_out_scales.append(max_v / (127.0 * 127.0))
out_linear_in = self.fake_quant( out_linear_in = self.fake_quant(
out_linear_in, self.out_linear_in_scales[i] out_linear_in, self.out_linear_in_scales[i]
) )
...@@ -455,9 +446,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -455,9 +446,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
) )
out = paddle.cast( out = paddle.cast(
paddle.cast(out, 'float32') paddle.cast(out, 'float32') * self.out_linear_out_scales[i],
* self.out_linear_in_scales[i]
/ self.out_linear_out_scales[i],
self.x_type, self.x_type,
) )
...@@ -476,8 +465,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -476,8 +465,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
max_v = paddle.max(paddle.abs(paddle.cast(ffn_ln_out, 'float32')))[ max_v = paddle.max(paddle.abs(paddle.cast(ffn_ln_out, 'float32')))[
0 0
] ]
self.ffn1_in_scales.append(max_v) self.ffn1_in_scales.append(1 / max_v)
self.ffn1_out_scales.append((127.0 * 127.0)) self.ffn1_out_scales.append(max_v / (127.0 * 127.0))
ffn_ln_out = self.fake_quant(ffn_ln_out, self.ffn1_in_scales[i]) ffn_ln_out = self.fake_quant(ffn_ln_out, self.ffn1_in_scales[i])
ffn1_out = paddle.nn.functional.linear( ffn1_out = paddle.nn.functional.linear(
...@@ -485,9 +474,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -485,9 +474,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
) )
ffn1_out = paddle.cast( ffn1_out = paddle.cast(
paddle.cast(ffn1_out, 'float32') paddle.cast(ffn1_out, 'float32') * self.ffn1_out_scales[i],
* self.ffn1_in_scales[i]
/ self.ffn1_out_scales[i],
self.x_type, self.x_type,
) )
...@@ -495,10 +482,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -495,10 +482,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_out = self.dropout(self.activation(ffn1_out)) ffn1_out = self.dropout(self.activation(ffn1_out))
max_v = paddle.max(paddle.abs(paddle.cast(ffn1_out, 'float32')))[0] max_v = paddle.max(paddle.abs(paddle.cast(ffn1_out, 'float32')))[0]
# self.ffn2_in_scales.append(127.0 / max_v) self.ffn2_in_scales.append(1 / max_v)
self.ffn2_in_scales.append(max_v) self.ffn2_out_scales.append(max_v / (127.0 * 127.0))
self.ffn2_out_scales.append((127.0 * 127.0))
# print('ffn2_in_scales ', i, self.ffn2_in_scales[i])
ffn1_out = self.fake_quant(ffn1_out, self.ffn2_in_scales[i]) ffn1_out = self.fake_quant(ffn1_out, self.ffn2_in_scales[i])
ffn2_out = paddle.nn.functional.linear( ffn2_out = paddle.nn.functional.linear(
...@@ -506,16 +491,12 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -506,16 +491,12 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
) )
ffn2_out = paddle.cast( ffn2_out = paddle.cast(
paddle.cast(ffn2_out, 'float32') paddle.cast(ffn2_out, 'float32') * self.ffn2_out_scales[i],
* self.ffn2_in_scales[i]
/ self.ffn2_out_scales[i],
self.x_type, self.x_type,
) )
ffn2_out = ffn2_out + self.ffn2_proj_bias_tensor ffn2_out = ffn2_out + self.ffn2_proj_bias_tensor
residual_out = attn_out + self.dropout(ffn2_out) residual_out = attn_out + self.dropout(ffn2_out)
# print("residual ", attn_out)
# print("residual_out ", residual_out)
final_out = residual_out final_out = residual_out
if not self.pre_layer_norm: if not self.pre_layer_norm:
final_out = self.ffn_norm(residual_out) final_out = self.ffn_norm(residual_out)
...@@ -644,23 +625,18 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -644,23 +625,18 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_weights, ffn1_biases = [], [] ffn1_weights, ffn1_biases = [], []
ffn2_weights, ffn2_biases = [], [] ffn2_weights, ffn2_biases = [], []
ffn_ln_scales, ffn_ln_biases = [], [] ffn_ln_scales, ffn_ln_biases = [], []
# Input scales: list of value
qkv_in_scale = [] qkv_in_scale = []
out_linear_in_scale = [] out_linear_in_scale = []
ffn1_in_scale = [] ffn1_in_scale = []
ffn2_in_scale = [] ffn2_in_scale = []
qkv_out_scales_tensor = paddle.ones( # Output dequant scales: list of tensor
[self.layers, 3 * self.embed_dim], 'float32' qkv_out_scales = []
) out_linear_out_scales = []
out_linear_out_scales_tensor = paddle.ones( ffn1_out_scales = []
[self.layers, self.embed_dim], 'float32' ffn2_out_scales = []
)
ffn1_out_scales_tensor = paddle.ones(
[self.layers, 4 * self.embed_dim], 'float32'
)
ffn2_out_scales_tensor = paddle.ones(
[self.layers, self.embed_dim], 'float32'
)
for i in range(self.layers): for i in range(self.layers):
qkv_weights.append(qkv_weight_tensor) qkv_weights.append(qkv_weight_tensor)
...@@ -680,10 +656,30 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -680,10 +656,30 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_in_scale.append(self.ffn1_in_scales[i]) ffn1_in_scale.append(self.ffn1_in_scales[i])
ffn2_in_scale.append(self.ffn2_in_scales[i]) ffn2_in_scale.append(self.ffn2_in_scales[i])
qkv_out_scales_tensor[i, :] *= self.qkv_out_scales[i] qkv_out_scale = (
out_linear_out_scales_tensor[i, :] *= self.out_linear_out_scales[i] paddle.ones([3 * self.embed_dim], 'float32')
ffn1_out_scales_tensor[i, :] *= self.ffn1_out_scales[i] * self.qkv_out_scales[i]
ffn2_out_scales_tensor[i, :] *= self.ffn2_out_scales[i] )
out_linear_out_scale = (
paddle.ones([self.embed_dim], 'float32')
* self.out_linear_out_scales[i]
)
ffn1_out_scale = (
paddle.ones([4 * self.embed_dim], 'float32')
* self.ffn1_out_scales[i]
)
ffn2_out_scale = (
paddle.ones([self.embed_dim], 'float32')
* self.ffn2_out_scales[i]
)
qkv_out_scales.append(qkv_out_scale)
out_linear_out_scales.append(out_linear_out_scale)
ffn1_out_scales.append(ffn1_out_scale)
ffn2_out_scales.append(ffn2_out_scale)
if self.has_cache_kv: if self.has_cache_kv:
cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=True)) cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=True))
...@@ -713,10 +709,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase): ...@@ -713,10 +709,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
trans_qkvw=True, trans_qkvw=True,
ring_id=-1, ring_id=-1,
name=None, name=None,
qkv_out_scales=qkv_out_scales_tensor, qkv_out_scales=qkv_out_scales,
out_linear_out_scales=out_linear_out_scales_tensor, out_linear_out_scales=out_linear_out_scales,
ffn1_out_scales=ffn1_out_scales_tensor, ffn1_out_scales=ffn1_out_scales,
ffn2_out_scales=ffn2_out_scales_tensor, ffn2_out_scales=ffn2_out_scales,
num_head=self.num_heads, num_head=self.num_heads,
dim_head=self.head_dim, dim_head=self.head_dim,
dim_ffn=4 * self.embed_dim, dim_ffn=4 * self.embed_dim,
......
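Editor's note: the Python test changes above make the scale convention explicit: the input scale becomes `1 / max_abs(activation)` and the output dequant scale becomes `max_abs(activation) / (127 * 127)` (the test's weight scale is effectively 1, so it does not appear). Quantization therefore multiplies by `127 * in_scale` and dequantization multiplies by `out_scale`, with no division in the kernels. A self-contained C++ round-trip under those assumptions (values are made up):

```cpp
// Round-trip sketch of the scale convention used by the updated test:
//   in_scale  = 1 / max_abs(activation)
//   out_scale = max_abs(activation) / (127 * 127)   (weight scale assumed 1)
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float max_abs = 4.0f;  // hypothetical activation range
  const float in_scale = 1.0f / max_abs;
  const float out_scale = max_abs / (127.0f * 127.0f);

  const float x = 2.5f;    // activation value
  const int8_t w = 127;    // weight representing 1.0, quantized with scale 1
  const int8_t qx = static_cast<int8_t>(
      std::round(std::min(127.0f, std::max(-127.0f, 127.0f * in_scale * x))));
  const int32_t acc = static_cast<int32_t>(qx) * static_cast<int32_t>(w);
  const float y = static_cast<float>(acc) * out_scale;  // dequantized output
  std::printf("x=%.3f  dequantized y=%.3f (expect ~x, since w ~ 1.0)\n", x, y);
  return 0;
}
```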