提交 5d7d5482 编写于 作者: M Michał Gallus 提交者: Tao Luo

INT8 Fully-connected (#17641)

* Implement Int8 FC

* Integrate FC into INT8v2

test=develop

* int8 FC: transpose weights before computing scales

test=develop

* Add support for activation_type string in FC

test=develop

* Disable MKL-DNN's FC in VGG16 and 19

test=develop

* Disable FC quantization when mkldnn FC is disabled

test=develop

* Solve PADDLE_ENFORCES in FC int8

* Fix Paddle enforces and remove const cast

test=develop

* Fix style changes

test=develop

* Fix quantizer_tester test and add fc quantization

test=develop

* Fix FC test fail on CUDA

* Remove unnecessary log from quantize placement pass

test=develop

* Add Thread ID to FC hash key

test=develop

* Add comments to MKL-DNN FC Kernel

test=develop

* Refactor quantizer

test=develop

* Fix linter issues

test=develop

* Fix crash in slim googlenet

test=develop

* Fix PADDLE_ENFORCE messages

test=develop
上级 b639a882
......@@ -186,10 +186,14 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
......
......@@ -905,15 +905,17 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Input
auto *input_var = pattern->NewNode(input_repr())
->AsInput()
->assert_is_op_input("fc", "Input");
// Filter
auto *fc_weight_var = pattern->NewNode(weights_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "W");
// Bias
auto *fc_bias_var = pattern->NewNode(bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("fc", "Bias");
// Output
auto *fc_out_var = pattern->NewNode(output_repr())
......@@ -921,7 +923,8 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({x, fc_weight_var, fc_bias_var}).LinksTo({fc_out_var});
fc_op->LinksFrom({input_var, fc_weight_var, fc_bias_var})
.LinksTo({fc_out_var});
return fc_out_var;
}
......
......@@ -517,6 +517,7 @@ struct FCMKLDNN : public PatternBase {
// declare operator node's name
PATTERN_DECL_NODE(fc);
// declare variable node's name
PATTERN_DECL_NODE(input);
PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output);
......
......@@ -17,6 +17,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -43,6 +44,13 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
std::string input_name, double scale_to_one,
bool is_unsigned,
std::string scale_attr_name) const {
auto inputs = op->Op()->InputNames();
bool name_found =
std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
PADDLE_ENFORCE_EQ(
name_found, true,
platform::errors::InvalidArgument("%s isn't the input of the %s operator",
input_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
......@@ -122,6 +130,13 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name,
double scale_to_one, bool is_unsigned,
std::string scale_attr_name) const {
auto outputs = op->Op()->OutputNames();
bool name_found =
std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
PADDLE_ENFORCE_EQ(name_found, true,
platform::errors::InvalidArgument(
"%s isn't the output of the %s operator", output_name,
op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
......@@ -228,6 +243,66 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
PrettyLogDetail(msg_ss.str().c_str());
}
void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
auto* fc_input = gpd.mutable_pattern()
->NewNode("fc_quantizer/input")
->AsInput()
->assert_is_op_input("fc", "Input");
fc_pattern(fc_input, false);
int quantize_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize fc op";
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
auto* fc_op_desc = fc->Op();
// skip if should not be quantized
if (fc_op_desc->GetAttrIfExists<bool>("use_quantizer") != true ||
fc_op_desc->GetAttrIfExists<bool>("use_mkldnn") != true)
return;
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
auto input_scale = scales[input->Name()].second.data<double>()[0];
bool is_input_unsigned = scales[input->Name()].first;
QuantizeInput(g, fc, input, "Input", input_scale, is_input_unsigned,
"Scale_in");
auto weight_scale_tensor = scales[weights->Name()].second;
EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
weight_scale_tensor.numel(), 1};
eigen_tensor *= static_cast<double>(S8_MAX);
std::vector<float> filter_scale{
weight_scale_tensor.data<double>(),
weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};
fc->Op()->SetAttr("Scale_weights", filter_scale);
auto output_scale = scales[output->Name()].second.data<double>()[0];
bool is_output_unsigned = scales[output->Name()].first;
DequantizeOutput(g, fc, output, "Out", output_scale, is_output_unsigned,
"Scale_out");
++quantize_fc_count;
};
gpd(graph, handler);
AddStatis(quantize_fc_count);
std::stringstream msg_ss;
msg_ss << "--- quantized " << quantize_fc_count << " fc ops";
PrettyLogDetail(msg_ss.str().c_str());
}
void CPUQuantizePass::QuantizePool(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
......@@ -418,6 +493,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeConcat(graph);
QuantizePriorBox(graph);
QuantizeTranspose(graph);
QuantizeFc(graph);
}
} // namespace ir
......
......@@ -46,6 +46,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
void QuantizeFc(Graph* graph) const;
void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const;
......
......@@ -62,6 +62,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (inputs.size() > 1) op->SetInput("W", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer);
op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", std::vector<float>{1.0f});
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
......@@ -71,13 +75,13 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
namespace {
static const std::initializer_list<std::string> variable_names{
"a", "w1", "c", "d", "w2", "e", "f", "g",
"h", "w3", "b1", "i", "j", "w4", "b2"};
"a", "w1", "c", "d", "w2", "e", "f", "g", "h",
"w3", "b1", "i", "j", "w4", "b2", "w5", "b3"};
// (a,w1)->Conv1->c and c->Pool1->d
//
// (d,w2)->Conv2->e and e->Pool2->f
//
// d->Dropout1->g and g->Fc1->h and (h,w3,b1,i)->Conv3->j
// d->Dropout1->g and (g, w5, b3)->Fc1->h and (h,w3,b1,i)->Conv3->j
//
// (d,w4, b2)->Conv4->i
ProgramDesc BuildProgramDesc(bool use_mkldnn, bool use_quantizer) {
......@@ -98,7 +102,8 @@ ProgramDesc BuildProgramDesc(bool use_mkldnn, bool use_quantizer) {
SetOp(&prog, "pool2d", "Pool2", {"e"}, {"f"}, use_mkldnn, use_quantizer);
SetOp(&prog, "dropout", "Dropout1", {"d"}, {"g"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"g"}, {"h"}, use_mkldnn);
SetOp(&prog, "fc", "Fc1", {"g", "w5", "b3"}, {"h"}, use_mkldnn,
use_quantizer);
SetOp(&prog, "conv2d", "Conv3", {"h", "w3", "b1", "i"}, {"j"}, use_mkldnn,
use_quantizer);
......@@ -194,13 +199,13 @@ TEST(CpuQuantizePass, quantize) {
// (d->QUANT3->IN3,w2)->Conv2->OUT3->DEQUANT3->e and
// e->QUANT4->IN4->Pool2->OUT4->DEQUANT4->f
//
// d->Dropout1->g and g->Fc1->h and
// d->Dropout1->g and (g->QUANT8->IN8,w5,b3)->Fc1->OUT7->DEQUANT7->h and
// (h->QUANT5->IN5,w3,b1,i->QUANT6->IN6)->Conv3->OUT5->DEQUANT5->j
//
// (d->QUANT7->IN7,w4, b2)->Conv4->DEQUANT6->OUT6->i
// Insert nodes: 7 Quant + 7 IN + 6 OUT + 6 DEQUANT
int added_nodes = 7 + 7 + 6 + 6;
MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 7, 6, added_nodes,
// Insert nodes: 8 Quant + 8 IN + 7 OUT + 7 DEQUANT
int added_nodes = 8 + 8 + 7 + 7;
MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 8, 7, added_nodes,
2.0f * 127);
}
......
......@@ -26,12 +26,11 @@ namespace framework {
namespace ir {
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
Init("fc_mkldnn_pass", graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
......@@ -49,18 +48,25 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
OpDesc* desc = fc->Op();
auto in_size = fc->inputs[0]->Var()->GetShape().size();
if (in_size != 2 && in_size != 4) {
auto dims = fc->inputs[0]->Var()->GetShape();
auto dim_num = dims.size();
bool are_dims_supported = dim_num == 2 || dim_num == 4;
constexpr size_t height_axis = 2;
constexpr size_t width_axis = 3;
bool is_size_supported =
dim_num == 4 ? (dims[width_axis] == 1 && dims[height_axis] == 1) : true;
if (!are_dims_supported || !is_size_supported) {
VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
VLOG(3) << "Or when width and height are different than one";
return;
}
desc->SetAttr("use_mkldnn", true);
PADDLE_ENFORCE(subgraph.count(x));
found_fc_count++;
};
......
......@@ -276,7 +276,7 @@ class MkldnnQuantizerTest : public testing::Test {
std::pair<bool, framework::LoDTensor> GetMaxChScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const {
return mkldnn_quantizer->GetMaxChScalingFactor(var_tensor, is_unsigned);
return mkldnn_quantizer->GetMaxChScalingFactor(var_tensor, is_unsigned, 0);
}
std::pair<bool, framework::LoDTensor> GetKLScalingFactor(
......
......@@ -37,6 +37,11 @@ using framework::LoDTensor;
using framework::ir::Graph;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixDoubleArray =
Eigen::Array<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixArray =
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstEigenMatrixArrayMap = Eigen::Map<const EigenMatrixArray>;
using string::PrettyLogH1;
static LoDTensor CreateScaleTensor(int64_t channels_num = 1);
......@@ -66,7 +71,7 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
bool is_unsigned = false;
bool compute_scale = true;
if (is_output) {
if (op->Type() == "conv2d") {
if (op->Type() == "conv2d" || op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
......@@ -138,7 +143,12 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale(
scales_[var_name] = GetMaxScalingFactor(var_tensor, is_unsigned);
break;
case ScaleAlgo::MAX_CH:
scales_[var_name] = GetMaxChScalingFactor(var_tensor, is_unsigned);
scales_[var_name] = GetMaxChScalingFactor(var_tensor, is_unsigned,
/*is_transposed*/ false);
break;
case ScaleAlgo::MAX_CH_T:
scales_[var_name] = GetMaxChScalingFactor(var_tensor, is_unsigned,
/*is_transposed*/ true);
break;
case ScaleAlgo::KL:
scales_[var_name] = GetKLScalingFactor(var_tensor, is_unsigned);
......@@ -319,7 +329,7 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxScalingFactor(
std::pair<bool, LoDTensor>
AnalysisPredictor::MkldnnQuantizer::GetMaxChScalingFactor(
const LoDTensor& var_tensor, bool is_unsigned) const {
const LoDTensor& var_tensor, bool is_unsigned, bool is_transposed) const {
PADDLE_ENFORCE(var_tensor.dims().size() > 0, "Tensor dimension is empty.");
ConstEigenVectorArrayMap eigen_tensor{var_tensor.data<float>(),
......@@ -331,18 +341,23 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxChScalingFactor(
"Tensor is claimed to be unsigned, but its min value (%f) is < 0.0",
min_val);
int channels = var_tensor.dims()[0];
LoDTensor scale_tensor = CreateScaleTensor(channels);
auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
for (int i = 0; i < channels; ++i) {
const auto tensor = var_tensor.Slice(i, i + 1);
auto dims = var_tensor.dims();
constexpr int num_col_dims = 1;
auto flattened_dims = framework::flatten_to_2d(dims, num_col_dims);
ConstEigenMatrixArrayMap eigen_tensor_mat{
var_tensor.data<float>(), flattened_dims[0], flattened_dims[1]};
ConstEigenVectorArrayMap eigen_tensor{tensor.data<float>(), tensor.numel(),
1};
float max_abs = eigen_tensor.abs().maxCoeff();
scale_ptr[i] = 1.0 / max_abs;
EigenMatrixDoubleArray scales;
if (is_transposed) {
scales = 1.0 / eigen_tensor_mat.cast<double>().abs().colwise().maxCoeff();
} else {
scales = 1.0 / eigen_tensor_mat.cast<double>().abs().rowwise().maxCoeff();
}
int output_channel_axis = is_transposed;
int channels = dims[output_channel_axis];
LoDTensor scale_tensor = CreateScaleTensor(channels);
auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
std::copy(scales.data(), scales.data() + scales.size(), scale_ptr);
return std::make_pair(is_unsigned, scale_tensor);
}
......
......@@ -79,7 +79,8 @@ class AnalysisPredictor::MkldnnQuantizer {
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
std::pair<bool, framework::LoDTensor> GetMaxChScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
const framework::LoDTensor& var_tensor, bool is_unsigned,
bool is_transposed) const;
std::pair<bool, framework::LoDTensor> GetMaxScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
......
......@@ -37,6 +37,11 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["transpose2"]["X"] = ScaleAlgo::KL;
rules_["transpose2"]["Out"] = ScaleAlgo::NONE;
rules_["fc"]["Input"] = ScaleAlgo::KL;
rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T;
rules_["fc"]["Bias"] = ScaleAlgo::NONE;
rules_["fc"]["Out"] = ScaleAlgo::KL;
}
ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
......@@ -26,10 +26,12 @@ namespace paddle {
// Algorithms for finding scale of quantized Tensors.
enum class ScaleAlgo {
NONE, // Do not compute scale
MAX, // Find scale based on the maximum absolute value
MAX_CH, // Find scale based on the maximum absolute value per channel
KL, // Find scale based on KL Divergence
NONE, // Do not compute scale
MAX, // Find scale based on the max absolute value
MAX_CH, // Find scale based on the max absolute value per output channel
MAX_CH_T, // Find scale based on the max absolute value per output channel
// of a transposed tensor
KL, // Find scale based on KL Divergence
};
struct MkldnnQuantizerConfig {
......
......@@ -93,13 +93,21 @@ class FCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input");
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kFCMKLDNNINT8
: kFCMKLDNNFP32;
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};
......@@ -132,6 +140,27 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
/* int8 parameters */
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddAttr<float>("Scale_in",
"(float, default 1.0f), The quantize scale of input data")
.SetDefault(1.0f);
AddAttr<std::vector<float>>("Scale_weights",
"(std::vector<float>, default {1.0f}), The "
"quantize scale of weights data")
.SetDefault({1.0f});
AddAttr<float>("Scale_out",
"(float, default 1.0f), The quantize scale of output data")
.SetDefault(1.0f);
AddAttr<bool>("force_fp32_output",
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8")
.SetDefault(false);
AddComment(R"DOC(
Fully Connected Operator.
......
......@@ -21,6 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum { kFCMKLDNNFP32 = 1, kFCMKLDNNINT8 = 2 };
using Tensor = framework::Tensor;
......
......@@ -78,7 +78,6 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(dst_tz.size(), memory::format::nchw));
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data));
......
......@@ -37,7 +37,7 @@ using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::prop_kind;
template <typename T>
template <typename T_in, typename T_w, typename T_out>
class FCPrimitiveFactory {
public:
explicit FCPrimitiveFactory(const mkldnn::engine& engine) : engine_(engine) {}
......@@ -47,19 +47,29 @@ class FCPrimitiveFactory {
const Tensor* bias, LoDTensor* output,
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output);
// If primitive has already been created and cached, don't create new one,
// but update input and output data pointers and return it.
if (fc_) {
UpdateDataPointers(ctx, output, input);
return *fc_;
}
auto src_desc = CreateMemDescriptor(input, input->format());
input_ = CreateMemory(src_desc, input);
auto src_desc = CreateMemDescriptor<T_in>(input, input->format());
input_ = CreateMemory<T_in>(src_desc, input);
// Since MKL-DNN doesn't support 4D column-major data formats in
// inner_product
// primitive, transpose the weights to be in row-major format
weights_ = TransposeWeights(weights);
if (src_desc.data.ndims == 4) {
weights_ = CreateFourDimWeightsMemory(input, weights);
}
// If int8 data type is desired, weights are quantized to signed int8
QuantizeWeights(ctx);
auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any);
// Choose MKLDNNMemoryFormat::any so that MKL-DNN can determine itself what
// is the best format for output during the creation of inner product
// primitive descriptor
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
fc_ = CreateFcPrimitive(*input_, *weights_, dst_desc, bias, output, ctx);
return *fc_;
......@@ -68,14 +78,18 @@ class FCPrimitiveFactory {
private:
void UpdateDataPointers(const ExecutionContext& ctx, Tensor* out,
const Tensor* in) {
input_->set_data_handle(const_cast<T*>(in->data<T>()));
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
input_->set_data_handle(to_void_cast(in->data<T_in>()));
output_->set_data_handle(out->mutable_data<T_out>(ctx.GetPlace()));
// If the primitive exists, but the output tensor has changed its
// variable, update its format to what has been determined in first
// call to CreateFcPrimitive method.
if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
}
// Choose weight memory format based on input memory format
MKLDNNMemoryFormat MatchWeightFormat(MKLDNNMemoryFormat fmt) {
using format = MKLDNNMemoryFormat;
switch (fmt) {
......@@ -85,11 +99,14 @@ class FCPrimitiveFactory {
return format::oIhw8i;
case format::nchw:
return format::oihw;
case format::nhwc:
return format::hwio;
default:
return format::format_undef;
}
}
// Convert data from one data format to another
mkldnn::memory Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc, const void* src_data) {
auto src_mem = memory({src_desc, engine_}, const_cast<void*>(src_data));
......@@ -101,18 +118,46 @@ class FCPrimitiveFactory {
return dst_mem;
}
// Convert data from one data format to another and rescale it.
// If the desired data type is (un)signed int8, quantization occurs here.
mkldnn::memory Reorder(const memory& src_mem,
const memory::primitive_desc& dst_pd,
const std::vector<float>& scale_data) {
mkldnn::memory dst_mem = mkldnn::memory(dst_pd);
mkldnn::primitive_attr attributes;
// According to MKL-DNN's documentation mask determines along which
// dimensions should the scale be applied.
// 0 - Single scale applied to whole tensor
// 1 - Apply Scale along a slice of each dimension which index is 1.
// In case of weights quantization, that dimension is output,
// becuase we perform per-output-channel quantization
int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data);
auto reorder =
mkldnn::reorder(mkldnn::reorder::primitive_desc(
src_mem.get_primitive_desc(), dst_pd, attributes),
src_mem, dst_mem);
stream(stream::kind::eager).submit({reorder}).wait();
return dst_mem;
}
template <typename T>
static mkldnn::memory::desc CreateMemDescriptor(const std::vector<int>& dims,
MKLDNNMemoryFormat format) {
return platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(),
format);
}
template <typename T>
static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
MKLDNNMemoryFormat format) {
auto dims = framework::vectorize<int>(tensor->dims());
return CreateMemDescriptor(dims, format);
return CreateMemDescriptor<T>(dims, format);
}
template <typename T>
mkldnn::memory CreateMemory(const mkldnn::memory::desc& desc,
const Tensor* tensor) {
return CreateMemory(desc, tensor->data<T>());
......@@ -123,12 +168,102 @@ class FCPrimitiveFactory {
return memory({desc, engine_}, const_cast<void*>(data));
}
// Transpose weights through MKL-DNN's reorder from io to oi format.
mkldnn::memory TransposeWeights(const Tensor* weights) {
auto dims = framework::vectorize<int>(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, weights->data<T>());
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, weights->data<float>());
}
// Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) {
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> bias_scales(weight_scales_num);
#pragma omp parallel for
for (size_t i = 0; i < weight_scales_num; i++) {
if (scale_weights_data[i] == 0.0)
bias_scales[i] = 1.0f;
else
bias_scales[i] = scale_in_data * scale_weights_data[i];
}
return bias_scales;
}
// Correct output scale, to take into account scaling of input and weights
// Since the data that comes out of input and weight multiplication is
// scaled with its own scales, this data needs to be divided by
// those scales to normalise them back to what their floating-point range
// was. Then we multiply them by desired output scale we want on the output.
std::vector<float> ComputeOutputShiftScale(const ExecutionContext& ctx) {
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
// If the output will be in floats, we don't multiply by scale_out.
auto scale_out_data = ctx.Attr<bool>("force_fp32_output")
? 1.0f
: ctx.Attr<float>("Scale_out");
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> output_shift_scale(weight_scales_num);
#pragma omp parallel for
for (size_t i = 0; i < weight_scales_num; i++) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
scale_out_data / (scale_in_data * scale_weights_data[i]);
}
return output_shift_scale;
}
// Computing MKL-DNN's scaling mask which determines along which dimension
// slice should the scaling be applied. For more data plase refer to:
// https://intel.github.io/mkl-dnn/group__c__api__attributes.html
// Section dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales
int CreateMask(int slice_dimension, bool is_multi_channel_quantizied) {
return is_multi_channel_quantizied ? 1 << slice_dimension : 0;
}
void QuantizeWeights(const ExecutionContext& ctx) {
auto quantized_desc = weights_->get_primitive_desc().desc();
quantized_desc.data.data_type =
(mkldnn_data_type_t)platform::MKLDNNGetDataType<T_w>();
weights_ = Reorder(*weights_, {quantized_desc, engine_},
ctx.Attr<std::vector<float>>("Scale_weights"));
}
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx) {
auto bias_scales = ComputeBiasScales(ctx);
bias_ = Reorder(*bias_, fc_prim_desc.bias_primitive_desc(), bias_scales);
}
// Fuse relu into FC with activation type attribute has been set to 'relu'
mkldnn::primitive_attr CreatePostOps(const ExecutionContext& ctx) {
mkldnn::primitive_attr attributes;
mkldnn::post_ops post_operations;
auto output_shift_scale = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
if (ctx.Attr<std::string>("activation_type") == "relu") {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
attributes.set_post_ops(post_operations);
return attributes;
}
inner_product_forward CreateFcPrimitive(const memory& src_memory,
......@@ -136,21 +271,34 @@ class FCPrimitiveFactory {
const memory::desc& dst_desc,
const Tensor* bias, Tensor* output,
const ExecutionContext& ctx) {
// Acquire descriptors needed for creation of inner_product primitive
// descriptor
const auto weights_desc = weights_memory.get_primitive_desc().desc();
const auto src_desc = src_memory.get_primitive_desc().desc();
// Based on provided attributes, create attributes used by MKL-DNN to
// enable fused post-op activations such as 'relu'
const auto attrs = CreatePostOps(ctx);
// If bias exists, create inner_product primitive with or without bias
if (bias) {
auto bias_desc = CreateMemDescriptor(bias, bias->format());
bias_ = CreateMemory(bias_desc, bias);
auto bias_desc = CreateMemDescriptor<float>(bias, bias->format());
bias_ = CreateMemory<float>(bias_desc, bias);
// Create inner_product descriptor. At this point the format of output
// is determined.
auto fc_prim_desc =
CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc);
CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
// If int8 is desired, quantize bias into 32-bit signed int
QuantizeBias(fc_prim_desc, ctx);
// Based on format determined by inner_product, create output in desired
// memory format
output_ = CreateDstMemory(fc_prim_desc, ctx, output);
// Return MKL-DNN primitive ready to be fed into pipeline and executed
return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
*bias_, *output_);
} else {
auto fc_prim_desc = CreateFcPrimDesc(src_desc, weights_desc, dst_desc);
auto fc_prim_desc =
CreateFcPrimDesc(src_desc, weights_desc, dst_desc, attrs);
output_ = CreateDstMemory(fc_prim_desc, ctx, output);
return inner_product_forward(fc_prim_desc, src_memory, weights_memory,
......@@ -162,24 +310,39 @@ class FCPrimitiveFactory {
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_desc) {
const mkldnn::memory::desc& dst_desc,
const mkldnn::primitive_attr& attrs) {
auto fc_desc =
inner_product_forward::desc(prop_kind::forward_scoring, input_desc,
weights_desc, bias_desc, dst_desc);
return inner_product_forward::primitive_desc(fc_desc, engine_);
return inner_product_forward::primitive_desc(fc_desc, attrs, engine_);
}
mkldnn::inner_product_forward::primitive_desc CreateFcPrimDesc(
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& dst_desc) {
const mkldnn::memory::desc& dst_desc,
const mkldnn::primitive_attr& attrs) {
auto fc_desc = inner_product_forward::desc(prop_kind::forward, input_desc,
weights_desc, dst_desc);
return inner_product_forward::primitive_desc(fc_desc, engine_);
return inner_product_forward::primitive_desc(fc_desc, attrs, engine_);
}
// Since MKL-DNN requires the number of input dimensions to be
// equal to the number of weight dimensions, we have to convert
// weights to 4D memory if input is 4D. It also requires that
// all dimensions of weights and inputs agree, with an exception
// for the batch size and number of output channels (the first dim).
// In order to perform that we have to prepare the memory descriptor
// by hand, as MKL-DNN's reorder does not support conversion
// from one dimensionality to another. Hence, we set
// the first dimension of weights to resemble number of outputs
// and then we use the sizes of number of input channels as well
// as image width and height for latter dimensions. Then we create
// memories, find a format corresponding with input format and
// perform a converion.
mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
const Tensor* weights) {
auto input_dims = framework::vectorize<int>(input->dims());
......@@ -187,19 +350,22 @@ class FCPrimitiveFactory {
auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
auto dst_format = MatchWeightFormat(input->format());
auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oihw);
auto dst_desc = CreateMemDescriptor(dims, dst_format);
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oihw);
auto dst_desc = CreateMemDescriptor<float>(dims, dst_format);
return Reorder(src_desc, dst_desc, weights_->get_data_handle());
}
// Create output memory based on output tensor and inner_product
// primitive descriptor format chosen for output
mkldnn::memory CreateDstMemory(
const mkldnn::inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx, Tensor* output) {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
auto buffer_size = dst_prim_desc.get_size();
T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_prim_desc, to_void_cast<T>(output_data));
T_out* output_data =
output->mutable_data<T_out>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_prim_desc, to_void_cast<T_out>(output_data));
output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem;
}
......@@ -227,30 +393,63 @@ class FCPrimitiveFactory {
boost::optional<inner_product_forward> fc_;
};
template <typename T>
std::shared_ptr<FCPrimitiveFactory<T>> GetPrimitiveFactory(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
const Tensor* input, const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
// Attempt to fetch cached primitive factory based on provided parameters
// of input format, weight dimensions and output name.
// If not cached, create a new one.
template <typename T_in, typename T_w, typename T_out>
static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx, const Tensor* input,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(),
framework::vectorize<int>(weights->dims()), ctx.op().Output("Out"));
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T>>(dev_ctx.GetBlob(key));
std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator = std::make_shared<FCPrimitiveFactory<T>>(mkldnn_engine);
prim_creator =
std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(mkldnn_engine);
dev_ctx.SetBlob(key, prim_creator);
}
return prim_creator;
}
template <typename T>
class FCMKLDNNOpKernel : public framework::OpKernel<T> {
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template <typename T_in, typename T_w>
static inner_product_forward GetFcPrimitive(
const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx,
const LoDTensor* input, const Tensor* w, const Tensor* bias,
LoDTensor* output, const mkldnn::engine& mkldnn_engine, bool fuse_relu,
bool force_fp32_output) {
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
if (!is_int8 || force_fp32_output) {
return GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, ctx, input, w,
mkldnn_engine)
->CreateFcPrimitive(input, w, bias, output, ctx);
} else if (fuse_relu) {
return GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->CreateFcPrimitive(input, w, bias, output, ctx);
} else {
return GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, ctx, input, w,
mkldnn_engine)
->CreateFcPrimitive(input, w, bias, output, ctx);
}
}
template <typename T_in, typename T_w>
class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -259,9 +458,12 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T> {
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<LoDTensor>("Out");
auto prim_creator =
GetPrimitiveFactory<T>(dev_ctx, ctx, input, w, mkldnn_engine);
auto fc = prim_creator->CreateFcPrimitive(input, w, bias, output, ctx);
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto fc =
GetFcPrimitive<T_in, T_w>(dev_ctx, ctx, input, w, bias, output,
mkldnn_engine, fuse_relu, force_fp32_output);
stream(stream::kind::eager).submit({fc}).wait();
output->set_layout(DataLayout::kMKLDNN);
......@@ -270,5 +472,18 @@ class FCMKLDNNOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::FCMKLDNNOpKernel<float>);
// Weights of FC are by default stored using fp32, template argument of weight
// data type implies their destination data type. (What's eventually going to
// be used during computations of kernel).
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace,
FP32, ops::kFCMKLDNNFP32,
ops::FCMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace,
U8, ops::kFCMKLDNNINT8,
ops::FCMKLDNNOpKernel<uint8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace,
S8, ops::kFCMKLDNNINT8,
ops::FCMKLDNNOpKernel<int8_t, int8_t>);
......@@ -42,6 +42,7 @@ class TestFCMKLDNNOp(OpTest):
def setUp(self):
self.op_type = "fc"
self._cpu_only = True
self.use_mkldnn = True
self.create_data()
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册