Unverified commit 12486712, authored by RichardWooSJTU, committed by GitHub

Add int8 support in fused_multi_transformer_pass and fuse_multi_transformer_layer_pass (#48209)

* delete unnecessary shape and slice op
Co-authored-by: Your Name <you@example.com>
Parent 9ff99e9e
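
The passes in this change coordinate through a graph-level "enable_int8" attribute: the encoder weight-dequant pass sets it, the decoder weight-dequant pass flips it to true when an int8 subgraph is matched, and fuse_multi_transformer_layer_pass reads it to decide whether to match int8 or float transformer ops. A minimal sketch of that hand-off, using a stand-in graph class rather than Paddle's ir::Graph, could look like:

#include <iostream>
#include <map>
#include <string>

// Stand-in for ir::Graph attribute storage, for illustration only.
struct FakeGraph {
  std::map<std::string, bool> attrs;
  bool Has(const std::string& k) const { return attrs.count(k) != 0; }
  bool& Get(const std::string& k) { return attrs.at(k); }
  void Set(const std::string& k, bool v) { attrs[k] = v; }
};

int main() {
  FakeGraph graph;
  bool encoder_saw_int8 = false;  // result of the encoder weight-dequant match
  graph.Set("enable_int8", encoder_saw_int8);
  bool decoder_saw_int8 = true;   // result of the decoder weight-dequant match
  if (decoder_saw_int8) graph.Get("enable_int8") = true;
  bool enable_int8 = graph.Has("enable_int8") && graph.Get("enable_int8");
  std::cout << (enable_int8 ? "match int8 transformer ops\n"
                            : "match float transformer ops\n");
  return 0;
}
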
......@@ -96,6 +96,8 @@ pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_weight_dequant_linear_op_encoder_pass inference)
pass_library(delete_weight_dequant_linear_op_decoder_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
......
......@@ -121,14 +121,27 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0];
float input_scale;
if (input_scale_tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
const float* input_scale_data = input_scale_tensor.data<float>();
input_scale = input_scale_data[0];
} else if (input_scale_tensor.dtype() ==
paddle::experimental::DataType::FLOAT16) {
const phi::dtype::float16* input_scale_data =
input_scale_tensor.data<phi::dtype::float16>();
input_scale = static_cast<float>(input_scale_data[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented("%d is not supported.",
input_scale_tensor.dtype()));
}
int nums_any_ops = dequantize_linear_op_out->outputs.size();
for (int i = 0; i < nums_any_ops; ++i) {
auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
// link x to any_op2
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpDecoderPass::
DeleteWeightDequantLinearOpDecoderPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete the weight dequantize_linear op and record its scale on the consumer op
void DeleteWeightDequantLinearOpDecoderPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_dequant_linear_op_decoder_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightDequantLinearOpDecoderPass "
"should not be null."));
// Create pattern
patterns::DeleteWeightDequantLinearOpDecoderPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
bool is_int8 = false;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8 = true;
std::unordered_set<const Node*> nodes2rm = {};
auto* any_op2_desc = any_op2->Op();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<phi::DenseTensor>();
auto weight_scale_nums = weight_scale_tensor->numel();
if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT32) {
float* weight_scale_data = weight_scale_tensor->data<float>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i]);
}
} else if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT16) {
phi::dtype::float16* weight_scale_data =
weight_scale_tensor->data<phi::dtype::float16>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"%d is not supported.", weight_scale_tensor->dtype()));
}
int quant_axis = PADDLE_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per-layer quant_dequant: one scale for the whole op
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// Add attr to anyop 2
any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Encoder Pass is not supported for "
"per-channel quantization"));
}
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
if (is_int8) {
auto& enable_int8 = graph->Get<bool>("enable_int8");
enable_int8 = true;
}
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_decoder_pass,
paddle::framework::ir::DeleteWeightDequantLinearOpDecoderPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightDequantLinearOpDecoderPass : public FusePassBase {
public:
DeleteWeightDequantLinearOpDecoderPass();
virtual ~DeleteWeightDequantLinearOpDecoderPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpEncoderPass::
DeleteWeightDequantLinearOpEncoderPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete the weight dequantize_linear op and record its scale on the consumer op
void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_dequant_linear_op_encoder_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightDequantLinearOpEncoderPass "
"should not be null."));
// Create pattern
patterns::DeleteWeightDequantLinearOpEncoderPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
bool is_int8 = false;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8 = true;
std::unordered_set<const Node*> nodes2rm = {};
auto* any_op2_desc = any_op2->Op();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<phi::DenseTensor>();
auto weight_scale_nums = weight_scale_tensor->numel();
if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT32) {
float* weight_scale_data = weight_scale_tensor->data<float>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i]);
}
} else if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT16) {
phi::dtype::float16* weight_scale_data =
weight_scale_tensor->data<phi::dtype::float16>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"%d is not supported.", weight_scale_tensor->dtype()));
}
int quant_axis = PADDLE_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per-layer quant_dequant: one scale for the whole op
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// Add attr to anyop 2
any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Encoder Pass is not supported for "
"per-channel quantization"));
}
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
graph->Set("enable_int8", new bool(is_int8));
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_encoder_pass,
paddle::framework::ir::DeleteWeightDequantLinearOpEncoderPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightDequantLinearOpEncoderPass : public FusePassBase {
public:
DeleteWeightDequantLinearOpEncoderPass();
virtual ~DeleteWeightDequantLinearOpEncoderPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -118,9 +118,15 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// TODO(wufeisheng): Get enable_int8 attr from graph after
// fused_multi_transformer pass with int8 merged
bool enable_int8 = false;
if (graph->Has("enable_int8")) {
enable_int8 = graph->Get<bool>("enable_int8");
}
if (!enable_int8) {
VLOG(4)
<< "fuse_multi_layer_transformer_pass will match float transformer op "
"cause enable_int8 is not been set or set to false";
}
int num_fuse_op = 0;
bool is_decoder = false;
......@@ -209,7 +215,13 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
"OutLinearW",
"QKVBias",
"QKVW"};
if (enable_int8) {
std::vector<std::string> inputs_names_int8_supp = {
"FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"};
inputs_names.insert(inputs_names.end(),
inputs_names_int8_supp.begin(),
inputs_names_int8_supp.end());
}
for (const auto& input_name : inputs_names) {
MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
}
......@@ -227,6 +239,17 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
}
fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);
if (enable_int8) {
// Merge inputs scale
std::vector<std::string> attr_names = {"qkv_in_scale",
"out_linear_in_scale",
"ffn1_in_scale",
"ffn2_in_scale"};
for (const auto& name : attr_names) {
MergeAttrs<float>(fuse_op_descs, name);
}
}
////////////////
//// ReLink ////
////////////////
......
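
When enable_int8 is set, the layer-fusion step above also merges the per-layer scalar scale attributes (qkv_in_scale, out_linear_in_scale, ffn1_in_scale, ffn2_in_scale) onto the first fused op via MergeAttrs<float>. A hypothetical stand-alone illustration of that merge, using a plain attribute map instead of Paddle's OpDesc and MergeAttrs, might be:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Plain attribute map standing in for OpDesc; before the merge each layer
// holds a single float under e.g. "qkv_in_scale".
using AttrMap = std::map<std::string, std::vector<float>>;

void MergeScalarAttr(std::vector<AttrMap*>* layer_attrs,
                     const std::string& name) {
  std::vector<float> merged;
  for (AttrMap* attrs : *layer_attrs) merged.push_back((*attrs)[name].front());
  (*(*layer_attrs)[0])[name] = merged;  // layer 0 keeps all layers' scales
}

int main() {
  AttrMap layer0{{"qkv_in_scale", {0.50f}}};
  AttrMap layer1{{"qkv_in_scale", {0.25f}}};
  std::vector<AttrMap*> layers{&layer0, &layer1};
  MergeScalarAttr(&layers, "qkv_in_scale");
  for (float s : layer0["qkv_in_scale"]) std::cout << s << " ";  // 0.5 0.25
  std::cout << "\n";
  return 0;
}
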
......@@ -98,6 +98,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers));
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
if (pass.get() == nullptr)
......
......@@ -193,6 +193,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass");
......@@ -344,6 +345,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_decoder_fuse_qkv_pass");
......@@ -503,6 +505,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass");
......
......@@ -188,6 +188,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass");
......@@ -334,6 +335,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("enable_int8", new bool(false));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
......@@ -489,6 +491,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
layers.elementwise_add(attention_out, ffn_eltadd1_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("enable_int8", new bool(false));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
......
......@@ -3175,6 +3175,73 @@ void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() {
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteWeightDequantLinearOpEncoderPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
// while loop
auto *while0 =
pattern->NewNode(while0_repr())->assert_is_op("while")->AsOutput();
while0->LinksFrom({weight_dequantize_linear_op_out});
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteWeightDequantLinearOpDecoderPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
->AsInput()
......
......@@ -1765,6 +1765,39 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2);
};
struct DeleteWeightDequantLinearOpEncoderPattern : public PatternBase {
DeleteWeightDequantLinearOpEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(while0);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteWeightDequantLinearOpDecoderPattern : public PatternBase {
DeleteWeightDequantLinearOpDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantLinearOpPattern : public PatternBase {
DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
......
......@@ -46,7 +46,10 @@ static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass"};
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_encoder_pass",
"delete_weight_dequant_linear_op_decoder_pass"};
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
......
......@@ -165,6 +165,9 @@ const std::vector<std::string> kLiteSubgraphPasses({
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_encoder_pass",
"delete_weight_dequant_linear_op_decoder_pass",
"map_depthwise_conv_to_conv_pass",
"conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass",
......@@ -203,9 +206,12 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"map_depthwise_conv_to_conv_pass",
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"delete_quant_dequant_linear_op_pass", //
"delete_weight_dequant_linear_op_encoder_pass", //
"delete_weight_dequant_linear_op_decoder_pass", //
"map_depthwise_conv_to_conv_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
......@@ -27,6 +28,7 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using phi::backends::gpu::GpuLaunchConfig;
template <typename T>
class AttnMatmulINT8 {
......@@ -36,6 +38,9 @@ class AttnMatmulINT8 {
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
gpu_config_ = std::make_unique<GpuLaunchConfig>(
phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, m * n, DequantKernelVecSize));
}
~AttnMatmulINT8() {}
......@@ -50,7 +55,6 @@ class AttnMatmulINT8 {
phi::DenseTensor* bias_out,
const float quant_in_scale,
const phi::DenseTensor* dequant_out_scale,
const int quant_out_scale_offset,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
......@@ -74,9 +78,9 @@ class AttnMatmulINT8 {
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
dequant_out_scale->data<float>());
if (compute_bias_) {
// bias_out = output + bias
......@@ -99,11 +103,13 @@ class AttnMatmulINT8 {
phi::DenseTensor* input,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* bias_out) {
phi::DenseTensor* bias_out,
void* workspace = nullptr) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
dev_ctx_.stream(),
workspace);
}
// This function is used to execute GEMM, with input and output's types are
......@@ -115,8 +121,7 @@ class AttnMatmulINT8 {
phi::DenseTensor* output,
phi::DenseTensor* output_tmp,
phi::DenseTensor* bias_out,
const phi::DenseTensor* dequant_out_scale,
const int quant_out_scale_offset) {
const phi::DenseTensor* dequant_out_scale) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
......@@ -127,9 +132,9 @@ class AttnMatmulINT8 {
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
dequant_out_scale->data<float>());
if (compute_bias_) {
// bias_out = output + bias
......@@ -183,6 +188,7 @@ class AttnMatmulINT8 {
int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
std::unique_ptr<GpuLaunchConfig> gpu_config_;
};
} // namespace operators
......
......@@ -24,6 +24,20 @@ namespace dyl = paddle::platform::dynload;
namespace paddle {
namespace operators {
struct CublasLtAlgoParam {
int algoId;
int swizzle;
int customOption;
int tile;
int splitK_val;
int reductionScheme;
int stages;
size_t workspace_size;
};
const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
......@@ -99,38 +113,34 @@ class CublasLtHelper {
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
~CublasLtHelper() {
if (handle_) dyl::cublasLtDestroy(handle_);
if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_);
if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_);
if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_);
if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_);
}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream) {
cublasStatus_t status;
#if CUDA_VERSION >= 11020
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
cublasLtMatmulAlgo_t algo;
int algoId = 21;
int swizzle = 0;
int customOption = 0;
int tile = 15;
int splitK_val = 0;
int reductionScheme = 0;
#if CUDA_VERSION >= 11000
int stages = 23;
#endif
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
workspace_size_ = 0;
if (m >= 128) {
tile = 20;
stages = 17;
}
std::tuple<int, int, int> key(m_, k_, n_);
if (AlgoParamCache.count(key) != 0) {
auto value = AlgoParamCache.at(key);
algoId = value.algoId;
swizzle = value.swizzle;
customOption = value.customOption;
tile = value.tile;
splitK_val = value.splitK_val;
reductionScheme = value.reductionScheme;
stages = value.stages;
workspace_size_ = value.workspace_size;
}
dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType,
......@@ -140,30 +150,43 @@ class CublasLtHelper {
CUDA_R_32I,
CUDA_R_32I,
algoId,
&algo);
&algo_);
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
&algo_,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption),
sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo,
&algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val),
sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle));
&algo_,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING,
&(swizzle),
sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo,
&algo_,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme),
sizeof(int));
#if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
&algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif
#endif
}
~CublasLtHelper() {}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
void* workspace = nullptr) {
cublasStatus_t status;
status = dyl::cublasLtMatmul(handle_,
matmul_desc_,
&alpha_,
......@@ -176,13 +199,15 @@ class CublasLtHelper {
C_desc_,
C_dev,
C_desc_,
#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020
&algo,
#if CUDA_VERSION >= 11020
&algo_,
workspace,
workspace_size_,
#else
nullptr,
#endif
nullptr,
0,
#endif
stream);
PADDLE_ENFORCE_EQ(
status,
......@@ -199,12 +224,17 @@ class CublasLtHelper {
cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_;
cublasLtMatmulAlgo_t algo_;
int32_t alpha_;
int32_t beta_;
int m_;
int k_;
int n_;
size_t workspace_size_;
};
} // namespace operators
......
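
The helper above now picks cublasLt algorithm parameters once in the constructor, looking up a per-(m, k, n) entry in AlgoParamCache and falling back to hard-coded defaults when the key is absent (the cache is declared empty in this change). A small self-contained sketch of that lookup pattern, with placeholder values, is:

#include <cstddef>
#include <iostream>
#include <map>
#include <tuple>

// Mirrors CublasLtAlgoParam; the cached values below are placeholders, not
// tuned results.
struct AlgoParam {
  int algoId, swizzle, customOption, tile, splitK_val, reductionScheme, stages;
  size_t workspace_size;
};

int main() {
  std::map<std::tuple<int, int, int>, AlgoParam> algo_cache;
  algo_cache[std::make_tuple(4096, 1024, 1024)] =
      AlgoParam{21, 0, 0, 20, 0, 0, 17, 0};  // hypothetical entry
  int m = 4096, k = 1024, n = 1024;
  AlgoParam p{21, 0, 0, 15, 0, 0, 23, 0};  // defaults used when the key misses
  auto it = algo_cache.find(std::make_tuple(m, k, n));
  if (it != algo_cache.end()) p = it->second;
  std::cout << "tile=" << p.tile << " stages=" << p.stages
            << " workspace=" << p.workspace_size << "\n";
  return 0;
}
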
......@@ -86,7 +86,6 @@ __global__ void FusedDropoutActBias(
MaskType *mask,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -127,7 +126,6 @@ __global__ void FusedDropoutActBias(
act,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
......@@ -146,7 +144,13 @@ __global__ void FusedActBias(Functor act,
const uint64_t cols,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst) {
OutType *dst,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
......@@ -156,23 +160,42 @@ __global__ void FusedActBias(Functor act,
LoadInType src_vec;
LoadT bias_vec;
StoreOutType out_vec;
LoadFloat dequant_out_scale_vec;
for (int32_t idx = global_thread_idx * VecSize,
step = blockDim.x * gridDim.x * VecSize;
idx < elem_cnt;
idx += step) {
const int32_t col_idx = idx % cols;
phi::Load<InType, VecSize>(&src[idx], &src_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_idx],
&dequant_out_scale_vec);
if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
}
#pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
if (bias) {
out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
T tmp;
if (std::is_same<InType, int32_t>::value) {
tmp = static_cast<T>(static_cast<float>(src_vec[unroll_idx]) *
dequant_out_scale_vec[unroll_idx]);
if (bias) {
tmp = static_cast<T>(act(tmp + bias_vec[unroll_idx]));
} else {
tmp = static_cast<T>(act(tmp));
}
out_vec[unroll_idx] = quant_helper(tmp,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
if (bias) {
out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
}
}
}
phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
......@@ -202,7 +225,6 @@ void LaunchDropoutActBias(Functor act_functor,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -218,7 +240,7 @@ void LaunchDropoutActBias(Functor act_functor,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
if (is_test && (dequant_out_scale_data == nullptr)) {
if (is_test) {
const int32_t elem_cnt = rows * cols;
const int32_t pack_num = elem_cnt / VecSize;
const int32_t tmp_cols = cols / VecSize;
......@@ -227,8 +249,15 @@ void LaunchDropoutActBias(Functor act_functor,
const int grid_size = std::max(static_cast<int32_t>(1),
(pack_num + block_size - 1) / block_size);
FusedActBias<T, VecSize, Functor, InType, OutType>
<<<grid_size, block_size, 0, ctx.stream()>>>(
act_functor, elem_cnt, cols, src, bias, dst);
<<<grid_size, block_size, 0, ctx.stream()>>>(act_functor,
elem_cnt,
cols,
src,
bias,
dst,
quant_last_in_scale,
dequant_out_scale_data,
quant_next_in_scale);
} else {
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
......@@ -246,7 +275,6 @@ void LaunchDropoutActBias(Functor act_functor,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
} else {
......@@ -266,7 +294,6 @@ void LaunchDropoutActBias(Functor act_functor,
mask_data,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
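
For the int8 path, FusedActBias now dequantizes the int32 GEMM output with a per-column scale, adds the bias, applies the activation, and re-quantizes with the next layer's input scale. A CPU reference of that per-element flow, with made-up scales and ReLU standing in for the actual activation functor, could be:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Re-quantize with the next layer's input scale (scale == 1 / max_abs).
int8_t Quantize(float x, float next_in_scale) {
  float q = std::round(127.0f * next_in_scale * x);
  return static_cast<int8_t>(std::max(-127.0f, std::min(127.0f, q)));
}

int main() {
  const int cols = 4;
  std::vector<int32_t> gemm_out = {1000, -2000, 3000, 4000};  // int32 GEMM result
  std::vector<float> dequant_out_scale = {1e-4f, 1e-4f, 2e-4f, 2e-4f};  // per column
  std::vector<float> bias = {0.1f, 0.2f, -0.1f, 0.0f};
  float next_in_scale = 0.5f;  // made-up reciprocal scale for the next GEMM
  std::vector<int8_t> dst(cols);
  for (int c = 0; c < cols; ++c) {
    float x = static_cast<float>(gemm_out[c]) * dequant_out_scale[c] + bias[c];
    x = std::max(0.0f, x);  // ReLU stands in for the activation functor
    dst[c] = Quantize(x, next_in_scale);
    std::printf("%d ", static_cast<int>(dst[c]));
  }
  std::printf("\n");
  return 0;
}
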
......@@ -154,7 +154,6 @@ class FusedDropoutHelper {
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
auto increment = GetIncrement(ctx);
LaunchResidualDropoutBias<T, MaskType, InType, OutType>(
......@@ -173,7 +172,6 @@ class FusedDropoutHelper {
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
......@@ -212,7 +210,6 @@ class FusedDropoutHelper {
MaskType* mask,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -237,7 +234,6 @@ class FusedDropoutHelper {
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
......@@ -260,7 +256,6 @@ class FusedDropoutHelper {
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
......@@ -287,7 +282,6 @@ class FusedDropoutHelper {
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
......@@ -454,7 +448,6 @@ class FusedDropoutLayerNormHelper
LayerNormParamType<T>* variance,
const float quant_last_in_scale = 1.0,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -494,7 +487,6 @@ class FusedDropoutLayerNormHelper
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
......
......@@ -442,7 +442,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
OutType *__restrict__ y_ptr,
const float quant_last_in_scale = 1.0,
const float *__restrict__ quant_out_scale_ptr = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -504,9 +503,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
phi::Load<InType, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
&x_input[it]);
if (quant_out_scale_ptr != nullptr) {
phi::Load<float, VecSize>(
quant_out_scale_ptr + quant_out_scale_offset + col * VecSize,
&dequant_out_scale[it]);
phi::Load<float, VecSize>(quant_out_scale_ptr + col * VecSize,
&dequant_out_scale[it]);
}
col += THREADS_PER_ROW;
}
......@@ -543,7 +541,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
// dropout(x) + residual
if (std::is_same<InType, int32_t>::value) {
T tmp = (static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) +
bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
......@@ -567,7 +564,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
if (std::is_same<InType, int32_t>::value) {
// for int32 input, we need to dequantize.
T tmp = static_cast<T>(static_cast<float>(x_input[it][jt]) *
quant_last_in_scale /
dequant_out_scale[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
......@@ -752,7 +748,6 @@ void LaunchLayernormResidualDropoutBias(
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -844,7 +839,6 @@ void LaunchLayernormResidualDropoutBias(
layernorm_dst, \
quant_last_in_scale, \
dequant_out_scale_data, \
quant_out_scale_offset, \
quant_next_in_scale, \
quant_round_type, \
quant_max_bound, \
......
......@@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel {
CHECK_INPUTS(FFN1Weight);
CHECK_INPUTS(FFN2Weight);
// scale
CHECK_INPUTS(QKVOutScale);
CHECK_INPUTS(OutLinearOutScale);
CHECK_INPUTS(FFN1OutScale);
CHECK_INPUTS(FFN2OutScale);
CHECK_OUTPUT(Out);
// x: qkv's input [batch_size, seq_len, dim_embed]
......@@ -232,20 +238,24 @@ class FusedMultiTransformerINT8OpMaker
"In order to keep consistent with the PTQ/QAT calculation logic,"
"QKVOutScale should be max_bound * max_bound / max_range."
"Here max_range is per-channel weight scale."
"The shape of QKVOutScale is [num_layers, num_channels]")
.AsDispensable();
"The shape of QKVOutScale is [num_channels]")
.AsDispensable()
.AsDuplicable();
AddInput("OutLinearOutScale",
"OutLinearOutScale is used to dequantize out_linear output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
.AsDispensable()
.AsDuplicable();
AddInput("FFN1OutScale",
"FFN1OutScale is used to dequantize ffn1 output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
.AsDispensable()
.AsDuplicable();
AddInput("FFN2OutScale",
"FFN2OutScale is used to dequantize ffn2 output tensor."
"The definition and shape is the same as QKVOutScale")
.AsDispensable();
.AsDispensable()
.AsDuplicable();
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
.AsDispensable()
......
......@@ -48,16 +48,11 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// dequant output scales: a list of tensors (one per layer), each of size [n],
// where n is the gemm output size
auto *qkv_out_scale = ctx.Input<phi::DenseTensor>("QKVOutScale");
auto *out_linear_out_scale =
ctx.Input<phi::DenseTensor>("OutLinearOutScale");
auto *ffn1_out_scale = ctx.Input<phi::DenseTensor>("FFN1OutScale");
auto *ffn2_out_scale = ctx.Input<phi::DenseTensor>("FFN2OutScale");
int qkv_out_scale_n = qkv_out_scale->dims()[1];
int out_linear_out_scale_n = out_linear_out_scale->dims()[1];
int ffn1_out_scale_n = ffn1_out_scale->dims()[1];
int ffn2_out_scale_n = ffn2_out_scale->dims()[1];
auto qkv_out_scales = ctx.MultiInput<phi::DenseTensor>("QKVOutScale");
auto out_linear_out_scales =
ctx.MultiInput<phi::DenseTensor>("OutLinearOutScale");
auto ffn1_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN1OutScale");
auto ffn2_out_scales = ctx.MultiInput<phi::DenseTensor>("FFN2OutScale");
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
......@@ -132,6 +127,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
auto *transpose_out_2_data =
dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));
qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
......@@ -232,19 +228,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
// []. init workspace for cublasLt transform
Tensor input_workspace, output_workspace;
Tensor input_workspace, output_workspace, cublaslt_workspace;
// the input and output transform data use the CUBLASLT_ORDER_COL32 format
int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn),
n_max = std::max({output_size, dim_embed, dim_ffn});
input_workspace.Resize(
{{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}});
input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}});
dev_ctx.Alloc<int8_t>(&input_workspace,
input_workspace.numel() * sizeof(int8_t));
output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}});
output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}});
dev_ctx.Alloc<int32_t>(&output_workspace,
output_workspace.numel() * sizeof(int32_t));
cublaslt_workspace.Resize({{3000000}});
dev_ctx.Alloc<int8_t>(&cublaslt_workspace,
cublaslt_workspace.numel() * sizeof(int8_t));
// calc
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
......@@ -305,8 +305,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace,
&qkv_out,
qkv_in_scale[i],
qkv_out_scale,
i * qkv_out_scale_n,
qkv_out_scales[i],
quant_round_type,
quant_max_bound,
quant_min_bound);
......@@ -319,8 +318,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace,
&qkv_out,
qkv_in_scale[i],
qkv_out_scale,
i * qkv_out_scale_n,
qkv_out_scales[i],
quant_round_type,
quant_max_bound,
quant_min_bound);
......@@ -332,8 +330,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&qkv_out,
&output_workspace,
&qkv_out,
qkv_out_scale,
i * qkv_out_scale_n);
qkv_out_scales[i]);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2";
......@@ -441,8 +438,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace,
nullptr,
out_linear_in_scale[i],
out_linear_out_scale,
i * out_linear_out_scale_n,
out_linear_out_scales[i],
quant_round_type,
quant_max_bound,
quant_min_bound);
......@@ -473,8 +469,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
ln_mean_data,
ln_var_data,
out_linear_in_scale[i],
out_linear_out_scale->data<float>(),
i * out_linear_out_scale_n,
out_linear_out_scales[i]->data<float>(),
ffn1_in_scale[i],
quant_round_type,
quant_max_bound,
......@@ -504,11 +499,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// step6. ffn matmul1
if (pre_layer_norm) {
ffn1_linear_compute.ComputeForwardINT8ToINT8(ffn1_weights[i],
&input_workspace,
nullptr,
&output_workspace,
nullptr);
ffn1_linear_compute.ComputeForwardINT8ToINT8(
ffn1_weights[i],
&input_workspace,
nullptr,
&output_workspace,
nullptr,
cublaslt_workspace.data<int8_t>());
} else {
ffn1_linear_compute.ComputeForward(ffn1_weights[i],
buf1,
......@@ -518,8 +515,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace,
nullptr,
ffn1_in_scale[i],
ffn1_out_scale,
i * ffn1_out_scale_n,
ffn1_out_scales[i],
quant_round_type,
quant_max_bound,
quant_min_bound);
......@@ -539,8 +535,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
input_workspace.data<int8_t>(),
ffn1_dropout_mask_data,
ffn1_in_scale[i],
ffn1_out_scale->data<float>(),
i * ffn1_out_scale_n,
ffn1_out_scales[i]->data<float>(),
ffn2_in_scale[i],
quant_round_type,
quant_max_bound,
......@@ -560,11 +555,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
// step8. ffn matmul2
if (pre_layer_norm) {
ffn2_linear_compute.ComputeForwardINT8ToINT8(ffn2_weights[i],
&input_workspace,
nullptr,
&output_workspace,
nullptr);
ffn2_linear_compute.ComputeForwardINT8ToINT8(
ffn2_weights[i],
&input_workspace,
nullptr,
&output_workspace,
nullptr,
cublaslt_workspace.data<int8_t>());
} else {
ffn2_linear_compute.ComputeForward(ffn2_weights[i],
&ffn1_dropout_out,
......@@ -574,8 +571,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
&output_workspace,
nullptr,
ffn2_in_scale[i],
ffn2_out_scale,
i * ffn2_out_scale_n,
ffn2_out_scales[i],
quant_round_type,
quant_max_bound,
quant_min_bound);
......@@ -616,8 +612,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
ln_mean_data,
ln_var_data,
ffn2_in_scale[i],
ffn2_out_scale->data<float>(),
i * ffn2_out_scale_n,
ffn2_out_scales[i]->data<float>(),
qkv_in_scale[i + 1],
quant_round_type,
quant_max_bound,
......@@ -631,8 +626,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
buf1->data<T>(),
dropout_mask_out_data,
ffn2_in_scale[i],
ffn2_out_scale->data<float>(),
i * ffn2_out_scale_n,
ffn2_out_scales[i]->data<float>(),
1.0);
}
} else {
......
......@@ -49,7 +49,6 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
Functor act_func,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
......@@ -74,9 +73,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
}
// vectorize load data from global
phi::Load<InType, VecSize>(&src[row_id * cols + col_id], &src_vec);
phi::Load<float, VecSize>(
&dequant_out_scale_data[quant_out_scale_offset + col_id],
&quant_out_scale_vec);
phi::Load<float, VecSize>(&dequant_out_scale_data[col_id],
&quant_out_scale_vec);
if (residual) {
phi::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
}
......@@ -108,7 +106,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
T tmp;
if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_last_in_scale / quant_out_scale_vec[ii]);
quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii];
} else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
......@@ -172,7 +170,6 @@ __global__ void FusedResidualDropoutBias(
const bool is_test,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
......@@ -208,7 +205,6 @@ __global__ void FusedResidualDropoutBias(
relu,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......@@ -236,7 +232,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
const phi::GPUContext &ctx,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_next_in_scale = 1.0) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
......@@ -278,7 +273,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
} else {
FusedResidualDropoutBias<T, uint8_t, 1, InType, OutType>
......@@ -297,7 +291,6 @@ void LaunchResidualDropoutBias(const uint32_t rows,
is_test,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
}
......
......@@ -18,17 +18,24 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace paddle {
namespace operators {
using phi::backends::gpu::GpuLaunchConfig;
constexpr int DequantKernelVecSize = 4;
template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * inverse(scale) * static_cast<float>(input);
float quant_value = max_bound * scale * static_cast<float>(input);
if (round_type == 0) {
quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
} else {
......@@ -77,7 +84,7 @@ void quantize_kernel_launcher(const T* input,
const float min_bound,
gpuStream_t stream) {
// TODO(minghaoBD): optimize the kernel launch times when m==1 or n==1
dim3 grid((n + 31) / 32, (m + 31) / 32);
dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input,
......@@ -90,46 +97,48 @@ void quantize_kernel_launcher(const T* input,
min_bound);
}
// dequantize using weight scales and input scales
template <typename T>
template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output,
const int32_t* input,
const int m, // hidden
const int n, // batch size
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
const float* dequant_out_scale_data,
const int quant_out_scale_offset) {
int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden
int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size
bool check = ((m_id < m) && (n_id < n));
if (check) {
float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id];
output[n_id * m + m_id] =
static_cast<T>(static_cast<float>(input[n_id * m + m_id]) *
quant_in_scale / out_scale);
const float* dequant_out_scale_data) {
int numel = m * n;
int stride = blockDim.x * gridDim.x * VecSize;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int col_id = idx % n;
phi::AlignedVector<int32_t, VecSize> in_vec;
phi::AlignedVector<float, VecSize> out_scale_vec;
phi::AlignedVector<T, VecSize> out_vec;
for (; idx < numel; idx += stride) {
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] =
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
}
phi::Store<T, VecSize>(out_vec, output + idx);
}
}
template <typename T>
void dequantize_kernel_launcher(const int32_t* input,
T* output,
const int batch_size, // m
const int hidden_units, // n
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data,
const int quant_out_scale_offset) {
dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32);
dim3 block(32, 32);
dequantize_kernel<<<grid, block, 0, stream>>>(output,
input,
hidden_units,
batch_size,
quant_in_scale,
dequant_out_scale_data,
quant_out_scale_offset);
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
}
} // namespace operators
......
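Two conventions change in this header. `quant_helper` now treats `scale` as the reciprocal of the clipping range (the quantized value is `max_bound * scale * x` rather than `max_bound / scale * x`), and `dequantize_kernel` becomes a vectorized grid-stride loop (`DequantKernelVecSize = 4`) that multiplies every int32 element by its per-column scale, with the launch configuration supplied by the caller through `GpuLaunchConfig`. A rough NumPy equivalent of the two kernels, assuming symmetric int8 quantization and `round_type == 0` (ties to even, which is also NumPy's default rounding mode):

    import numpy as np

    def quantize(x, scale, max_bound=127.0, min_bound=-127.0):
        # scale is 1/amax, so this computes round(max_bound * x / amax).
        q = np.round(max_bound * scale * x.astype(np.float32))
        return np.clip(q, min_bound, max_bound).astype(np.int8)

    def dequantize(gemm_out_int32, dequant_out_scale):
        # Matches dequantize_kernel's col_id indexing:
        # out[i, j] = int32[i, j] * dequant_out_scale[j].
        return gemm_out_int32.astype(np.float32) * dequant_out_scale[None, :]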
......@@ -307,7 +307,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
self.attn_mask = None
def fake_quant(self, input, scale):
quant_value = 127.0 * (1.0 / scale) * paddle.cast(input, 'float32')
quant_value = 127.0 * scale * paddle.cast(input, 'float32')
quant_value = paddle.round(quant_value)
# No need to clip here because scale is the reciprocal of the max value
......@@ -333,11 +333,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
if self.pre_layer_norm:
ln1_out = self.norm(tensor_query)
max_v = paddle.max(paddle.abs(paddle.cast(ln1_out, 'float32')))[0]
# self.qkv_in_scales.append(127.0 / max_v)
self.qkv_in_scales.append(max_v)
self.qkv_out_scales.append(127.0 * 127.0)
# print('qkv_in_scales ', i, self.qkv_in_scales[i])
# print('qkv_out_scales ', i, self.qkv_out_scales[i])
self.qkv_in_scales.append(1 / max_v)
self.qkv_out_scales.append(max_v / (127.0 * 127.0))
# quant ln1_out
ln1_out = self.fake_quant(ln1_out, self.qkv_in_scales[i])
......@@ -345,9 +342,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
q = paddle.nn.functional.linear(ln1_out, self.q_weight_tensor)
# de quant
q = paddle.cast(
paddle.cast(q, 'float32')
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
paddle.cast(q, 'float32') * self.qkv_out_scales[i],
self.x_type,
)
......@@ -357,17 +352,13 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
k = paddle.nn.functional.linear(ln1_out, self.k_weight_tensor)
k = paddle.cast(
paddle.cast(k, 'float32')
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
paddle.cast(k, 'float32') * self.qkv_out_scales[i],
self.x_type,
)
k = k + self.k_proj_bias_tensor
v = paddle.nn.functional.linear(ln1_out, self.v_weight_tensor)
v = paddle.cast(
paddle.cast(v, 'float32')
* self.qkv_in_scales[i]
/ self.qkv_out_scales[i],
paddle.cast(v, 'float32') * self.qkv_out_scales[i],
self.x_type,
)
v = v + self.v_proj_bias_tensor
......@@ -442,10 +433,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
max_v = paddle.max(
paddle.abs(paddle.cast(out_linear_in, 'float32'))
)[0]
# self.out_linear_in_scales.append(127.0 / max_v)
self.out_linear_in_scales.append(max_v)
self.out_linear_out_scales.append((127.0 * 127.0))
self.out_linear_in_scales.append(1 / max_v)
self.out_linear_out_scales.append(max_v / (127.0 * 127.0))
out_linear_in = self.fake_quant(
out_linear_in, self.out_linear_in_scales[i]
)
......@@ -455,9 +446,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
)
out = paddle.cast(
paddle.cast(out, 'float32')
* self.out_linear_in_scales[i]
/ self.out_linear_out_scales[i],
paddle.cast(out, 'float32') * self.out_linear_out_scales[i],
self.x_type,
)
......@@ -476,8 +465,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
max_v = paddle.max(paddle.abs(paddle.cast(ffn_ln_out, 'float32')))[
0
]
self.ffn1_in_scales.append(max_v)
self.ffn1_out_scales.append((127.0 * 127.0))
self.ffn1_in_scales.append(1 / max_v)
self.ffn1_out_scales.append(max_v / (127.0 * 127.0))
ffn_ln_out = self.fake_quant(ffn_ln_out, self.ffn1_in_scales[i])
ffn1_out = paddle.nn.functional.linear(
......@@ -485,9 +474,7 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
)
ffn1_out = paddle.cast(
paddle.cast(ffn1_out, 'float32')
* self.ffn1_in_scales[i]
/ self.ffn1_out_scales[i],
paddle.cast(ffn1_out, 'float32') * self.ffn1_out_scales[i],
self.x_type,
)
......@@ -495,10 +482,8 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_out = self.dropout(self.activation(ffn1_out))
max_v = paddle.max(paddle.abs(paddle.cast(ffn1_out, 'float32')))[0]
# self.ffn2_in_scales.append(127.0 / max_v)
self.ffn2_in_scales.append(max_v)
self.ffn2_out_scales.append((127.0 * 127.0))
# print('ffn2_in_scales ', i, self.ffn2_in_scales[i])
self.ffn2_in_scales.append(1 / max_v)
self.ffn2_out_scales.append(max_v / (127.0 * 127.0))
ffn1_out = self.fake_quant(ffn1_out, self.ffn2_in_scales[i])
ffn2_out = paddle.nn.functional.linear(
......@@ -506,16 +491,12 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
)
ffn2_out = paddle.cast(
paddle.cast(ffn2_out, 'float32')
* self.ffn2_in_scales[i]
/ self.ffn2_out_scales[i],
paddle.cast(ffn2_out, 'float32') * self.ffn2_out_scales[i],
self.x_type,
)
ffn2_out = ffn2_out + self.ffn2_proj_bias_tensor
residual_out = attn_out + self.dropout(ffn2_out)
# print("residual ", attn_out)
# print("residual_out ", residual_out)
final_out = residual_out
if not self.pre_layer_norm:
final_out = self.ffn_norm(residual_out)
......@@ -644,23 +625,18 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_weights, ffn1_biases = [], []
ffn2_weights, ffn2_biases = [], []
ffn_ln_scales, ffn_ln_biases = [], []
# Input scales: list of values
qkv_in_scale = []
out_linear_in_scale = []
ffn1_in_scale = []
ffn2_in_scale = []
qkv_out_scales_tensor = paddle.ones(
[self.layers, 3 * self.embed_dim], 'float32'
)
out_linear_out_scales_tensor = paddle.ones(
[self.layers, self.embed_dim], 'float32'
)
ffn1_out_scales_tensor = paddle.ones(
[self.layers, 4 * self.embed_dim], 'float32'
)
ffn2_out_scales_tensor = paddle.ones(
[self.layers, self.embed_dim], 'float32'
)
# Output dequant scales: list of tensors
qkv_out_scales = []
out_linear_out_scales = []
ffn1_out_scales = []
ffn2_out_scales = []
for i in range(self.layers):
qkv_weights.append(qkv_weight_tensor)
......@@ -680,10 +656,30 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
ffn1_in_scale.append(self.ffn1_in_scales[i])
ffn2_in_scale.append(self.ffn2_in_scales[i])
qkv_out_scales_tensor[i, :] *= self.qkv_out_scales[i]
out_linear_out_scales_tensor[i, :] *= self.out_linear_out_scales[i]
ffn1_out_scales_tensor[i, :] *= self.ffn1_out_scales[i]
ffn2_out_scales_tensor[i, :] *= self.ffn2_out_scales[i]
qkv_out_scale = (
paddle.ones([3 * self.embed_dim], 'float32')
* self.qkv_out_scales[i]
)
out_linear_out_scale = (
paddle.ones([self.embed_dim], 'float32')
* self.out_linear_out_scales[i]
)
ffn1_out_scale = (
paddle.ones([4 * self.embed_dim], 'float32')
* self.ffn1_out_scales[i]
)
ffn2_out_scale = (
paddle.ones([self.embed_dim], 'float32')
* self.ffn2_out_scales[i]
)
qkv_out_scales.append(qkv_out_scale)
out_linear_out_scales.append(out_linear_out_scale)
ffn1_out_scales.append(ffn1_out_scale)
ffn2_out_scales.append(ffn2_out_scale)
if self.has_cache_kv:
cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=True))
......@@ -713,10 +709,10 @@ class TestFusedMultiTransformerInt8Op(unittest.TestCase):
trans_qkvw=True,
ring_id=-1,
name=None,
qkv_out_scales=qkv_out_scales_tensor,
out_linear_out_scales=out_linear_out_scales_tensor,
ffn1_out_scales=ffn1_out_scales_tensor,
ffn2_out_scales=ffn2_out_scales_tensor,
qkv_out_scales=qkv_out_scales,
out_linear_out_scales=out_linear_out_scales,
ffn1_out_scales=ffn1_out_scales,
ffn2_out_scales=ffn2_out_scales,
num_head=self.num_heads,
dim_head=self.head_dim,
dim_ffn=4 * self.embed_dim,
......
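The test now mirrors the reciprocal convention used by the kernels: each `*_in_scales` entry stores 1 / max(|x|), `fake_quant` multiplies by `127 * scale`, each `*_out_scales` entry stores max(|x|) / 127^2, and the dequant step collapses to a single multiply; the dequant scales are also passed to `fused_multi_transformer_int8` as one 1-D tensor per layer instead of a stacked [layers, hidden] tensor. A quick check that the old and new bookkeeping dequantize to the same value (numbers made up for illustration):

    max_v = 2.0          # max(|activation|) observed for this projection
    q = 100.0            # some int32 GEMM output element

    # Old convention (deleted lines): in_scale = max_v, out_scale = 127**2.
    old = q * max_v / (127.0 * 127.0)

    # New convention (added lines): out_scale = max_v / 127**2, one multiply.
    new = q * (max_v / (127.0 * 127.0))

    assert old == new    # same result; only where the factors live has changed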