未验证 提交 fed0ed34 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add fc-residual quantization (#46917)

* add fc-residual quantization

* revert removal of check for use_mkldnn

* fix bug

* add disable_logs

* review fix

call AreScalesPresentForNodes twice instead of using if-else

* rewrite residual input to output

* revert fc mkldnn taking residual data

* format fix

* fix LoDTensor->DenseTensor

* LoDTensor->DenseTensor

* output->input

* revert changes to unsupported script

revert changes to unsupported script

* remove fc ResidualData from the output blocklist in cpu_bfloat16_pass.cc
上级 41483383
...@@ -1163,21 +1163,12 @@ PDNode *patterns::FCMKLDNN::operator()(bool with_residual_data) { ...@@ -1163,21 +1163,12 @@ PDNode *patterns::FCMKLDNN::operator()(bool with_residual_data) {
if (with_residual_data) { if (with_residual_data) {
auto res_fc_var = pattern->NewNode(residual_data_repr()) auto res_fc_var = pattern->NewNode(residual_data_repr())
->AsInput() ->AsInput()
->assert_is_op_input("fc") ->assert_is_op_input("fc", "ResidualData");
// assert_is_op_input with two arguments doesn't work
// because ResidualData in FC is set as output with
// SetOutput so we do custom assert output
->assert_more([&](Node *x) {
for (auto *op : x->outputs)
if (IsNthOutput(x, op, "ResidualData", 0))
return true;
return false;
});
links_from.push_back(res_fc_var); links_from.push_back(res_fc_var);
} else { } else {
fc_op->assert_more([&](Node *x) { fc_op->assert_more([&](Node *x) {
if (!HasOutput(x, "ResidualData") || if (!HasInput(x, "ResidualData") ||
x->Op()->Output("ResidualData").size() == 0) x->Op()->Input("ResidualData").size() == 0)
return true; return true;
return false; return false;
}); });
......
...@@ -200,7 +200,6 @@ class DeQuantizer final : public Quanter { ...@@ -200,7 +200,6 @@ class DeQuantizer final : public Quanter {
std::unordered_map<std::string, std::vector<std::string>> block_list{ std::unordered_map<std::string, std::vector<std::string>> block_list{
{"layer_norm", {"layer_norm",
{"Mean", "Variance"}}, // not used in inference in MKLDNN {"Mean", "Variance"}}, // not used in inference in MKLDNN
{"fc", {"ResidualData"}}, // artifical output, already dequantized
{"matmul", {"ResidualData"}}, // artifical output, already dequantized {"matmul", {"ResidualData"}}, // artifical output, already dequantized
{"matmul_v2", {"matmul_v2",
{"ResidualData"}}}; // artifical output, already dequantized {"ResidualData"}}}; // artifical output, already dequantized
......
...@@ -515,16 +515,17 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -515,16 +515,17 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
((with_residual_data) ? "with residual connection" : "")); ((with_residual_data) ? "with residual connection" : ""));
} }
void CPUQuantizePass::QuantizeFc(Graph* graph) const { void CPUQuantizePass::QuantizeFc(Graph* graph, bool with_residual_data) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope_}; patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
fc_pattern(false /* with_residual */); fc_pattern(with_residual_data);
int quantize_fc_count = 0; int quantize_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "Quantize fc op"; VLOG(4) << "Quantize fc op " << (with_residual_data ? "with" : "without")
<< " residual data";
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
// skip if should not be quantized // skip if should not be quantized
...@@ -532,6 +533,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -532,6 +533,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
LogQuantizationDisabled(fc); LogQuantizationDisabled(fc);
return; return;
} }
if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) { if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false"); MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
return; return;
...@@ -546,6 +548,26 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -546,6 +548,26 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
return; return;
} }
if (with_residual_data) {
GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data, fc_pattern);
if (!AreScalesPresentForNodes({residual_data})) {
MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
return;
}
bool is_residual_unsigned{false};
auto residual_scale =
GetScaleValueForNode(residual_data, &is_residual_unsigned);
QuantizeInput(g,
fc,
residual_data,
"ResidualData",
residual_scale,
is_residual_unsigned,
"Scale_in_eltwise");
}
bool is_input_unsigned{false}; bool is_input_unsigned{false};
auto input_scale = GetScaleValueForNode(input, &is_input_unsigned); auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
QuantizeInput( QuantizeInput(
...@@ -576,7 +598,9 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -576,7 +598,9 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(quantize_fc_count); AddStatis(quantize_fc_count);
LogQuantizedOpsCounter("fc", quantize_fc_count); LogQuantizedOpsCounter("fc",
quantize_fc_count,
with_residual_data ? "with residual connection" : "");
} }
void CPUQuantizePass::QuantizePool(Graph* graph) const { void CPUQuantizePass::QuantizePool(Graph* graph) const {
...@@ -1228,7 +1252,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1228,7 +1252,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizePool(graph); QuantizePool(graph);
QuantizeConcat(graph); QuantizeConcat(graph);
QuantizePriorBox(graph); QuantizePriorBox(graph);
QuantizeFc(graph); QuantizeFc(graph, false /* with_residual_data */);
QuantizeFc(graph, true /* with_residual_data */);
QuantizeMatmul(graph, false /* with_residual_data */); QuantizeMatmul(graph, false /* with_residual_data */);
QuantizeMatmul(graph, true /* with_residual_data */); QuantizeMatmul(graph, true /* with_residual_data */);
QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "reshape2", "X");
......
...@@ -49,8 +49,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -49,8 +49,8 @@ class CPUQuantizePass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_data = false) const; void QuantizeConv(Graph* graph, bool with_residual_data) const;
void QuantizeFc(Graph* graph) const; void QuantizeFc(Graph* graph, bool with_residual_data) const;
void QuantizePool(Graph* graph) const; void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const; void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const; void QuantizePriorBox(Graph* graph) const;
......
...@@ -337,7 +337,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { ...@@ -337,7 +337,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
if (dequant_in->outputs.size() == 1) { if (dequant_in->outputs.size() == 1) {
if (any_op->Op()->Type() == "conv2d" || if (any_op->Op()->Type() == "conv2d" ||
any_op->Op()->Type() == "conv2d_transpose") { any_op->Op()->Type() == "conv2d_transpose" ||
any_op->Op()->Type() == "fc") {
// do not squash if fuse residual connection is true // do not squash if fuse residual connection is true
// because residual fusion does not support force output with fp32 // because residual fusion does not support force output with fp32
if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
...@@ -418,8 +419,8 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -418,8 +419,8 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
last_op_names.begin(), last_op_names.end(), quant_out->Name()), last_op_names.begin(), last_op_names.end(), quant_out->Name()),
last_op_names.end()); last_op_names.end());
last_op_names.push_back(first_quant_out->Name()); last_op_names.push_back(first_quant_out->Name());
last_op->Op()->SetInput(last_op_input_name, last_op_op->SetInput(last_op_input_name,
std::vector<std::string>(last_op_names)); std::vector<std::string>(last_op_names));
IR_NODE_LINK_TO(first_quant_out, last_op); IR_NODE_LINK_TO(first_quant_out, last_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out}); GraphSafeRemoveNodes(graph, {quant_op, quant_out});
......
...@@ -119,7 +119,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( ...@@ -119,7 +119,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
return; return;
} }
fc_op->Op()->SetOutput("ResidualData", {residual_data->Name()}); fc_op->Op()->SetInput("ResidualData", {residual_data->Name()});
fc_op->Op()->SetOutput("Out", {elementwise_out->Name()}); fc_op->Op()->SetOutput("Out", {elementwise_out->Name()});
fc_op->Op()->SetAttr("fuse_residual_connection", true); fc_op->Op()->SetAttr("fuse_residual_connection", true);
......
...@@ -29,18 +29,16 @@ namespace ir { ...@@ -29,18 +29,16 @@ namespace ir {
class Graph; class Graph;
namespace { void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
void LogEnabledOps(const int counter, const std::string& details) { PADDLE_ENFORCE_NOT_NULL(graph,
std::string msg_ss{"--- enabled FC MKL-DNN for "}; platform::errors::InvalidArgument(
msg_ss += counter + " fc ops " + details; "Pointer to graph argument should not be NULL."));
string::PrettyLogDetail(msg_ss.c_str()); Init("fc_mkldnn_pass", graph);
}
} // namespace
void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass"); patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(with_residual); // searching for fc+residual doesn't make sense at this stage
fc_pattern(false /*with_residual*/);
int found_fc_count = 0; int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -79,19 +77,12 @@ void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const { ...@@ -79,19 +77,12 @@ void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const {
AddStatis(found_fc_count); AddStatis(found_fc_count);
LogEnabledOps(found_fc_count, if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
(with_residual ? "with residual connection" (found_fc_count > 0)) {
: "without residual connection")); std::string msg_ss = "--- enabled FC MKL-DNN for " +
} std::to_string(found_fc_count) + " fc ops ";
string::PrettyLogDetail(msg_ss.c_str());
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { }
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
Init("fc_mkldnn_pass", graph);
ApplyPass(graph, true);
ApplyPass(graph, false);
} }
} // namespace ir } // namespace ir
......
...@@ -34,7 +34,6 @@ class FCMKLDNNPass : public FusePassBase { ...@@ -34,7 +34,6 @@ class FCMKLDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph* graph) const; void ApplyImpl(ir::Graph* graph) const;
void ApplyPass(ir::Graph* graph, bool with_residual) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -439,6 +439,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -439,6 +439,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("repeated_fc_relu_fuse_pass"); passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("fc_mkldnn_pass"); passes_.push_back("fc_mkldnn_pass");
passes_.push_back("fc_act_mkldnn_fuse_pass"); passes_.push_back("fc_act_mkldnn_fuse_pass");
passes_.push_back("fc_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass"); passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass");
passes_.push_back("batch_norm_act_fuse_pass"); passes_.push_back("batch_norm_act_fuse_pass");
passes_.push_back("softplus_activation_mkldnn_fuse_pass"); passes_.push_back("softplus_activation_mkldnn_fuse_pass");
......
...@@ -103,15 +103,16 @@ class FCMKLDNNHandler ...@@ -103,15 +103,16 @@ class FCMKLDNNHandler
dnnl::primitive_attr attributes; dnnl::primitive_attr attributes;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
std::vector<float> output_shift_scale; float sum_scale = 1.0f;
float scale = 1.0f; float activation_scale = 1.0f;
if (phi::funcs::is_int8<T_w>()) { if (phi::funcs::is_int8<T_w>()) {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); std::vector<float> output_shift_scale;
std::tie(output_shift_scale, sum_scale, activation_scale) =
GetOutputScales(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1); int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale); attributes.set_output_scales(mask, output_shift_scale);
} }
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") && if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) { ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
...@@ -120,9 +121,9 @@ class FCMKLDNNHandler ...@@ -120,9 +121,9 @@ class FCMKLDNNHandler
// ReLU from "fc_fuse_pass" // ReLU from "fc_fuse_pass"
if (ctx.Attr<std::string>("activation_type") == "relu") { if (ctx.Attr<std::string>("activation_type") == "relu") {
post_operations.append_eltwise( post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); activation_scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
} }
platform::AppendActivation(ctx, post_operations, scale); platform::AppendActivation(ctx, post_operations, activation_scale);
if (ctx.HasAttr("fused_output_scale")) { if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale"); float scale_alpha = ctx.Attr<float>("fused_output_scale");
...@@ -136,18 +137,22 @@ class FCMKLDNNHandler ...@@ -136,18 +137,22 @@ class FCMKLDNNHandler
// Compute the bias scales so that its values correspond to the // Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication // scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales( std::vector<float> GetBiasScales(const framework::ExecutionContext& ctx) {
const float scale_in, const std::vector<float>& scale_weights) { if (ctx.HasAttr("Bias_scales")) {
std::vector<float> bias_scales(scale_weights.size()); return ctx.Attr<std::vector<float>>("Bias_scales");
} else {
for (size_t i = 0; i < bias_scales.size(); ++i) { const float scale_in = ctx.Attr<float>("Scale_in");
if (scale_weights[i] == 0.0) const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
bias_scales[i] = 1.0f; std::vector<float> bias_scales(scale_weights.size());
else
bias_scales[i] = scale_in * scale_weights[i]; for (size_t i = 0; i < bias_scales.size(); ++i) {
if (scale_weights[i] == 0.0)
bias_scales[i] = 1.0f;
else
bias_scales[i] = scale_in * scale_weights[i];
}
return bias_scales;
} }
return bias_scales;
} }
// Correct output scale, to take into account scaling of input and weights // Correct output scale, to take into account scaling of input and weights
...@@ -155,32 +160,44 @@ class FCMKLDNNHandler ...@@ -155,32 +160,44 @@ class FCMKLDNNHandler
// scaled with its own scales, this data needs to be divided by // scaled with its own scales, this data needs to be divided by
// those scales to normalise them back to what their floating-point range // those scales to normalise them back to what their floating-point range
// was. Then we multiply them by desired output scale we want on the output. // was. Then we multiply them by desired output scale we want on the output.
std::tuple<std::vector<float>, float> ComputeOutputShiftScale( std::tuple<std::vector<float>, float, float> GetOutputScales(
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
auto scale_in_data = ctx.Attr<float>("Scale_in"); if (ctx.HasAttr("Sum_scale")) {
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights"); return std::make_tuple(ctx.Attr<std::vector<float>>("Output_shift_scale"),
bool has_activation = !ctx.Attr<std::string>("activation_type").empty(); ctx.Attr<float>("Sum_scale"),
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); ctx.Attr<float>("Activation_scale"));
} else {
// If the output will be in floats, we don't multiply by scale_out. auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
float scale = (!force_fp32_output && has_activation) bool has_activation = !ctx.Attr<std::string>("activation_type").empty();
? ctx.Attr<float>("Scale_out") bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
: 1.0f; bool fuse_residual_conn = ctx.HasAttr("fuse_residual_connection") &&
float inner_scale = (force_fp32_output || has_activation) ctx.Attr<bool>("fuse_residual_connection");
? 1.0f auto scale_in_eltwise_data = ctx.HasAttr("Scale_in_eltwise")
: ctx.Attr<float>("Scale_out"); ? ctx.Attr<float>("Scale_in_eltwise")
const size_t weight_scales_num = scale_weights_data.size(); : 1.0f;
for (size_t i = 0; i < weight_scales_num; ++i) { // If the output will be in floats, we don't multiply by scale_out.
if (scale_weights_data[i] == 0.0)
scale_weights_data[i] = inner_scale; float activation_scale = (!force_fp32_output && has_activation)
else ? ctx.Attr<float>("Scale_out")
scale_weights_data[i] = : 1.0f;
inner_scale / (scale_in_data * scale_weights_data[i]); float scale_out_data = (force_fp32_output || has_activation)
? 1.0f
: ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
const size_t weight_scales_num = scale_weights_data.size();
for (size_t i = 0; i < weight_scales_num; ++i) {
if (scale_weights_data[i] == 0.0)
scale_weights_data[i] = scale_out_data;
else
scale_weights_data[i] =
scale_out_data / (scale_in_data * scale_weights_data[i]);
}
return std::make_tuple(scale_weights_data, sum_scale, activation_scale);
} }
return make_tuple(scale_weights_data, scale);
} }
// Computing MKL-DNN's scaling mask which determines along which dimension // Computing MKL-DNN's scaling mask which determines along which dimension
...@@ -240,9 +257,7 @@ class FCMKLDNNHandler ...@@ -240,9 +257,7 @@ class FCMKLDNNHandler
} }
std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder( std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
const phi::DenseTensor* bias, const framework::ExecutionContext& ctx, const phi::DenseTensor* bias) {
const float scale_in,
const std::vector<float>& scale_weights) {
const float* bias_data = bias->data<float>(); const float* bias_data = bias->data<float>();
if (phi::funcs::is_int8<T_w>() == false) { if (phi::funcs::is_int8<T_w>() == false) {
...@@ -255,7 +270,7 @@ class FCMKLDNNHandler ...@@ -255,7 +270,7 @@ class FCMKLDNNHandler
this->dev_ctx_.GetBlob(bias_key)); this->dev_ctx_.GetBlob(bias_key));
if (!memory_p) { if (!memory_p) {
const auto& scale_data = ComputeBiasScales(scale_in, scale_weights); const auto& scale_data = GetBiasScales(ctx);
dnnl::primitive_attr attrs; dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1); int mask = CreateMask(0, scale_data.size() > 1);
...@@ -316,7 +331,7 @@ class FCMKLDNNHandler ...@@ -316,7 +331,7 @@ class FCMKLDNNHandler
const ExecutionContext& ctx, phi::DenseTensor* out) { const ExecutionContext& ctx, phi::DenseTensor* out) {
if (ctx.HasAttr("fuse_residual_connection") && if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) { ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData"); auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out->dims(), out->dims(),
...@@ -393,7 +408,6 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -393,7 +408,6 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
const auto* bias = ctx.Input<phi::DenseTensor>("Bias"); const auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto out = ctx.Output<LoDTensor>("Out"); auto out = ctx.Output<LoDTensor>("Out");
const float scale_in = ctx.Attr<float>("Scale_in");
const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights"); const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
std::shared_ptr<dnnl::inner_product_forward> fc_p; std::shared_ptr<dnnl::inner_product_forward> fc_p;
...@@ -430,7 +444,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -430,7 +444,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
std::make_shared<dnnl::memory>(inner_product_cache->dst_mem); std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
if (ctx.HasAttr("fuse_residual_connection") && if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) { ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData"); auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
out->ShareDataWith(*residual_param); out->ShareDataWith(*residual_param);
} }
auto out_ptr = out->mutable_data<T_out>( auto out_ptr = out->mutable_data<T_out>(
...@@ -460,8 +474,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -460,8 +474,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
dst_memory_p = handler.AcquireCustomDstMemory(ctx, out); dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
if (bias) { if (bias) {
bias_memory_p = bias_memory_p = handler.AcquireBiasMemoryWithReorder(ctx, bias);
handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
} }
fc_p = handler.AcquireForwardPrimitive(); fc_p = handler.AcquireForwardPrimitive();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册