Unverified commit cfa513f7, authored by YangQun, committed by GitHub

[ONEDNN] Upgrade oneDNN version to v3.1 (#52463)

* squash pick the poc code
* fix build after rebase
* fix int8 conv and fc uts
* Fix and clean-up Get_SRC_Scale_Memory
* fix floating point fc uts
* fix test_analyzer_int8_googlenet
* test_analyzer_int8_mobilenetv1
* fix int8 mobilenet v2 and v3
* fix build error after rebase
* [oneDNN] rename library version
* fix conv bias datatype
* try to fix import error
* fix rebase error
* [oneDNN] pack library into python wheel
* add MKLDNN_SHARED_LIB_3 to env_dict
* fix test_analyzer_bert
* fix fill_constant op kernel
* fix ernie and matmul op ut
* fix softplus ut
* fix conv+relu6 fusion ut
* fix hardswish fusion
* fix quant+transpose fusion ut
* fix sgd ut
* fix int8 matmul with flatten
* fix fc+scale fusion
* fix conv/matmul+gelu fusion uts
* fix rebase error
* Revert "fix conv/matmul+gelu fusion uts"
This reverts commit 47eb5e49972bd8f7271a233def9bfb3e98ce78e1.
* upgrade to onednn v3.1
* remove older version onednn
* use densetensor::data() for achieving mean and var in layernorm impl
* comments for atol of integer tests
* fix clang-format
* Revert "remove older version onednn"
This reverts commit 783e57ddfd4401254596eae7d47adb9b03590c09.
* improve binary handle
* fix expand kernel
* Revert "use densetensor::data() for achieving mean and var in layernorm impl"
* always use forward_inference for conv
* remove activation scales
* rollback changes to mkldnn.cmake
* address comments
* port changes to dequantize kernel
* fix merge error
* fix fused_elementwise_kernel
* upgrade onednn version to v3.1.1
* fix some approval error
* fix error msg format
* remove old onednn libs
* try to fix symbolic link issue
* fix cinn test case segfault
* do not explicit link test with onednn
* remove unnecessary changes
* integrate CINN with onednn v3
* link with mkldnn project
* fix cinn build file

---------
Co-authored-by: Tomasz Socha <tomasz.socha@intel.com>
Co-authored-by: Chen, Xinyu1 <xinyu1.chen@intel.com>
Co-authored-by: tianshuo78520a <707759223@qq.com>
Parent be3a6fa7
@@ -174,8 +174,8 @@ if(WITH_MKL)
     target_link_libraries(cinnapi cinn_mklml)
     add_dependencies(cinnapi cinn_mklml)
     if(WITH_MKLDNN)
-      target_link_libraries(cinnapi mkldnn)
-      add_dependencies(cinnapi mkldnn)
+      target_link_libraries(cinnapi ${MKLDNN_LIB})
+      add_dependencies(cinnapi ${MKLDNN_PROJECT})
     endif()
   endif()
@@ -224,8 +224,8 @@ function(gen_cinncore LINKTYPE)
     target_link_libraries(${CINNCORE_TARGET} cinn_mklml)
     add_dependencies(${CINNCORE_TARGET} cinn_mklml)
     if(WITH_MKLDNN)
-      target_link_libraries(${CINNCORE_TARGET} mkldnn)
-      add_dependencies(${CINNCORE_TARGET} mkldnn)
+      target_link_libraries(${CINNCORE_TARGET} ${MKLDNN_LIB})
+      add_dependencies(${CINNCORE_TARGET} ${MKLDNN_PROJECT})
     endif()
   endif()
......
-# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2017-2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@ set(MKLDNN_INC_DIR
     "${MKLDNN_INSTALL_DIR}/include"
     CACHE PATH "mkldnn include directory." FORCE)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/mkldnn)
-set(MKLDNN_TAG 2089770c4818be8933c5e9d1dd3cbaeba1457667)
 # Introduce variables:
 # * CMAKE_INSTALL_LIBDIR
@@ -128,16 +127,12 @@ if(WIN32)
     VERBATIM)
   add_custom_target(mkldnn_cmd ALL DEPENDS ${MKLDNN_LIB})
 else()
-  set(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
-  set(MKLDNN_SHARED_LIB_1 ${MKLDNN_INSTALL_DIR}/libdnnl.so.1)
-  set(MKLDNN_SHARED_LIB_2 ${MKLDNN_INSTALL_DIR}/libdnnl.so.2)
+  set(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libdnnl.so.3)
   add_custom_command(
-    OUTPUT ${MKLDNN_SHARED_LIB_2}
+    OUTPUT ${MKLDNN_SHARED_LIB}
     COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}
-    COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB_1}
-    COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB_2}
     DEPENDS ${MKLDNN_PROJECT})
-  add_custom_target(mkldnn_cmd ALL DEPENDS ${MKLDNN_SHARED_LIB_2})
+  add_custom_target(mkldnn_cmd ALL DEPENDS ${MKLDNN_SHARED_LIB})
 endif()
 # generate a static dummy target to track mkldnn dependencies
......
@@ -130,17 +130,9 @@ function(copy_part_of_thrid_party TARGET DST)
       add_custom_command(
         TARGET ${TARGET}
         POST_BUILD
-        COMMAND strip -s ${dst_dir}/lib/libmkldnn.so.0
-        COMMENT "striping libmkldnn.so.0")
+        COMMAND strip -s ${dst_dir}/lib/libdnnl.so.3
+        COMMENT "striping libdnnl.so.3")
     endif()
-    add_custom_command(
-      TARGET ${TARGET}
-      POST_BUILD
-      COMMAND ${CMAKE_COMMAND} -E create_symlink libmkldnn.so.0
-              ${dst_dir}/lib/libdnnl.so.1
-      COMMAND ${CMAKE_COMMAND} -E create_symlink libmkldnn.so.0
-              ${dst_dir}/lib/libdnnl.so.2
-      COMMENT "Make a symbol link of libmkldnn.so.0")
   endif()
 endif()
......
@@ -58,9 +58,13 @@ void cinn_cpu_mkldnn_softmax_fp32(int batch,
   auto src_mem =
       memory(src_md, engine, reinterpret_cast<float*>(inputs->memory));
   auto dst_mem = memory(src_md, engine, reinterpret_cast<float*>(out->memory));
-  auto softmax_d = dnnl::softmax_forward::desc(
-      dnnl::prop_kind::forward_inference, src_md, axis);
-  auto softmax_pd = dnnl::softmax_forward::primitive_desc(softmax_d, engine);
+  auto softmax_pd =
+      dnnl::softmax_forward::primitive_desc(engine,
+                                            dnnl::prop_kind::forward_inference,
+                                            dnnl::algorithm::softmax_accurate,
+                                            src_md,
+                                            src_md,
+                                            axis);
   auto softmax_prim = dnnl::softmax_forward(softmax_pd);
   softmax_prim.execute(engine_stream,
@@ -117,19 +121,17 @@ void cinn_cpu_mkldnn_conv2d_nchw_fp32(int batch_size,
   auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any);
   auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::nchw);
-  auto conv_desc =
-      dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference,
-                                      dnnl::algorithm::convolution_direct,
-                                      conv_src_md,
-                                      conv_weights_md,
-                                      conv_dst_md,
-                                      conv_strides,
-                                      conv_dilations,
-                                      conv_paddings,
-                                      conv_paddings);
-  auto conv_prim_desc =
-      dnnl::convolution_forward::primitive_desc(conv_desc, cpu_engine);
+  auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(
+      cpu_engine,
+      dnnl::prop_kind::forward_inference,
+      dnnl::algorithm::convolution_direct,
+      conv_src_md,
+      conv_weights_md,
+      conv_dst_md,
+      conv_strides,
+      conv_dilations,
+      conv_paddings,
+      conv_paddings);
   auto conv_src_memory = conv_user_src_memory;
   auto conv_weights_memory = conv_user_weights_memory;
......
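For context on the API change behind the two hunks above: oneDNN v2.x first built an operation descriptor and then wrapped it in a primitive_desc, while v3.x constructs the primitive_desc directly from the engine, prop kind, algorithm, and memory descriptors. The following is a minimal, self-contained sketch of the v3 pattern; it is illustrative only (tensor shape, axis, and data values are made up, not taken from this patch).

#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// Illustrative only: v3 builds the primitive_desc directly; the v2
// softmax_forward::desc intermediate object no longer exists.
int main() {
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream engine_stream(engine);

  std::vector<float> data(2 * 8, 1.0f);  // made-up 2x8 input
  auto md = dnnl::memory::desc(
      {2, 8}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
  auto src_mem = dnnl::memory(md, engine, data.data());
  auto dst_mem = dnnl::memory(md, engine, data.data());  // in-place for brevity

  auto softmax_pd =
      dnnl::softmax_forward::primitive_desc(engine,
                                            dnnl::prop_kind::forward_inference,
                                            dnnl::algorithm::softmax_accurate,
                                            md,
                                            md,
                                            /*axis=*/1);
  dnnl::softmax_forward(softmax_pd)
      .execute(engine_stream,
               {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
  engine_stream.wait();
  return 0;
}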
@@ -185,7 +185,6 @@ if(WITH_MKLDNN)
   pass_library(elementwise_act_onednn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
-  pass_library(layer_norm_onednn_optimization_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
   pass_library(quant_transpose2_dequant_onednn_fuse_pass inference DIR mkldnn)
   pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
......
@@ -942,29 +942,6 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   return bn_out_var;
 }
-PDNode *patterns::LayerNormShiftScale::operator()() {
-  auto layer_norm_in = pattern->NewNode(layer_norm_in_repr())
-                           ->AsInput()
-                           ->assert_is_op_input("layer_norm", "X");
-  auto layer_norm_bias = pattern->NewNode(layer_norm_bias_repr())
-                             ->AsInput()
-                             ->assert_is_op_input("layer_norm", "Bias");
-  auto layer_norm_scale = pattern->NewNode(layer_norm_scale_repr())
-                              ->AsInput()
-                              ->assert_is_op_input("layer_norm", "Scale");
-  auto layer_norm_op =
-      pattern->NewNode(layer_norm_op_repr())->assert_is_op("layer_norm");
-  auto layer_norm_out = pattern->NewNode(layer_norm_out_repr())
-                            ->assert_is_op_output("layer_norm", "Y")
-                            ->AsOutput();
-  layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale})
-      .LinksTo({layer_norm_out});
-  return layer_norm_out;
-}
 PDNode *patterns::OperatorActivation::operator()(
     const std::string &operator_type, const std::string &activation_type) {
   auto *preceding_op =
......
@@ -526,19 +526,6 @@ struct ConvBN : public PatternBase {
   PATTERN_DECL_NODE(bn_saved_variance);
 };
-struct LayerNormShiftScale : public PatternBase {
-  LayerNormShiftScale(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "layer_norm_shift_scale") {}
-
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(layer_norm_in);
-  PATTERN_DECL_NODE(layer_norm_op);
-  PATTERN_DECL_NODE(layer_norm_bias);
-  PATTERN_DECL_NODE(layer_norm_scale);
-  PATTERN_DECL_NODE(layer_norm_out);
-};
 struct OperatorActivation : public PatternBase {
   OperatorActivation(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "operator_activation") {}
......
@@ -42,7 +42,7 @@ inline std::unordered_map<std::string, std::string> GetAttributeMap(
   if (act_type == "swish") {
     attr_map.emplace("beta", "fuse_alpha");
   } else if (act_type == "relu6") {
-    attr_map.emplace("threshold", "fuse_alpha");
+    attr_map.emplace("threshold", "fuse_beta");
   } else if (act_type == "hard_sigmoid") {
     attr_map.emplace("slope", "fuse_alpha");
     attr_map.emplace("offset", "fuse_beta");
@@ -73,6 +73,11 @@ inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
     }
   }
+  if (act_type == "hard_swish") {
+    fused_op->SetAttr("fuse_alpha", 1.f / 6.f);
+    fused_op->SetAttr("fuse_beta", 1.f / 2.f);
+  }
+
   if (act_type == "gelu" && act_op->HasAttr("approximate")) {
     std::string gelu_act_type =
         PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
......
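The relu6 and hard_swish changes above follow from how these activations end up as oneDNN v3 eltwise post-ops, whose parameters ride in (alpha, beta). The sketch below is a hedged illustration of that mapping, assuming relu6 is lowered to a clip-style eltwise and hard_swish to eltwise_hardswish; the exact lowering Paddle uses lives in onednn_reuse.h and this snippet is not code from the patch.

#include "oneapi/dnnl/dnnl.hpp"

// Illustrative mapping only: the eltwise post-op carries activation
// parameters in (alpha, beta), which is why relu6's threshold now lands in
// fuse_beta and hard_swish gets explicit alpha = 1/6, beta = 1/2.
dnnl::post_ops make_post_ops_sketch() {
  dnnl::post_ops ops;
  // relu6 as a clip-style post-op: clamp the result to [0, 6].
  ops.append_eltwise(dnnl::algorithm::eltwise_clip, 0.0f, 6.0f);
  // hard_swish: d = s * max(0, min(1, s / 6 + 1 / 2)).
  ops.append_eltwise(dnnl::algorithm::eltwise_hardswish, 1.0f / 6.0f, 0.5f);
  return ops;
}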
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/layer_norm_onednn_optimization_pass.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void LayerNormOneDNNOptimizationPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("layer_norm_onednn_optimization_pass", graph);
GraphPatternDetector gpd;
patterns::LayerNormShiftScale layer_norm_shift_scale_pattern(
gpd.mutable_pattern(), "layer_norm_onednn_optimization_pass");
layer_norm_shift_scale_pattern();
int found_layer_norm = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_op, layer_norm_op, layer_norm_shift_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, layer_norm_shift_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, layer_norm_shift_scale_pattern);
if (layer_norm_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, layer_norm_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4) << "Only oneDNN version of layer_norm can be optimized to "
"include Bias and Shift in a single tensor.";
return;
}
auto *scope = param_scope();
auto ln_bias_name = layer_norm_op->Op()->Input("Bias");
auto ln_scale_name = layer_norm_op->Op()->Input("Scale");
auto *ln_bias_tensor =
scope->FindVar(ln_bias_name[0])->GetMutable<phi::DenseTensor>();
auto *ln_scale_tensor =
scope->FindVar(ln_scale_name[0])->GetMutable<phi::DenseTensor>();
const int channels = ln_bias_tensor->dims()[0];
VarDesc scale_shift_desc(patterns::PDNodeName(
"layer_norm_onednn_optimization_pass", "ScaleShift"));
scale_shift_desc.SetShape({channels * 2});
scale_shift_desc.SetDataType(
framework::TransToProtoVarType(ln_bias_tensor->dtype()));
scale_shift_desc.SetPersistable(true);
auto scale_shift_node = g->CreateVarNode(&scale_shift_desc);
auto *scale_shift_tensor =
scope->Var(scale_shift_node->Name())->GetMutable<phi::DenseTensor>();
scale_shift_tensor->Resize(phi::make_ddim({channels * 2}));
memcpy(scale_shift_tensor->mutable_data<float>(phi::CPUPlace()),
ln_scale_tensor->data<float>(),
channels * sizeof(float));
memcpy(scale_shift_tensor->data<float>() + channels,
ln_bias_tensor->data<float>(),
channels * sizeof(float));
layer_norm_op->Op()->SetInput("ScaleShift", {scale_shift_node->Name()});
IR_NODE_LINK_TO(scale_shift_node, layer_norm_op);
found_layer_norm++;
};
gpd(graph, handler);
AddStatis(found_layer_norm);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_layer_norm > 0)
PrettyLogDetail("--- optimized %d layer_norms by merging Scale and Bias",
found_layer_norm);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(layer_norm_onednn_optimization_pass,
paddle::framework::ir::LayerNormOneDNNOptimizationPass);
REGISTER_PASS_CAPABILITY(layer_norm_onednn_optimization_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().GE(
"layer_norm", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class LayerNormOneDNNOptimizationPass : public FusePassBase {
public:
virtual ~LayerNormOneDNNOptimizationPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -158,7 +158,7 @@ if(WITH_MKL)
     if(WIN32)
       set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
     else()
-      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libdnnl.so.3)
     endif()
   endif()
 else()
......
@@ -601,10 +601,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   arg->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
   auto* builder = predictor_.config_.pass_builder();
-  builder->SetPasses({"cpu_quantize_pass",
-                      "cpu_quantize_squash_pass",
-                      "int8_scale_calculation_mkldnn_pass",
-                      "params_quantization_mkldnn_pass"});
+  builder->SetPasses({"cpu_quantize_pass", "cpu_quantize_squash_pass"});
   if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
   auto passes = builder->AllPasses();
   predictor_.argument_->SetIrAnalysisPasses(passes);
......
@@ -371,7 +371,6 @@ void CpuPassStrategy::EnableMKLDNN() {
            "softplus_activation_onednn_fuse_pass",   //
            "shuffle_channel_mkldnn_detect_pass",     //
            "elementwise_act_onednn_fuse_pass",       //
-           "layer_norm_onednn_optimization_pass",    //
            "operator_scale_onednn_fuse_pass",        //
            "operator_unsqueeze2_onednn_fuse_pass",   //
            "operator_reshape2_onednn_fuse_pass",     //
@@ -465,7 +464,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
       passes_.push_back("scale_matmul_fuse_pass");
       passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
       passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
-      passes_.push_back("layer_norm_onednn_optimization_pass");
       passes_.push_back("operator_scale_onednn_fuse_pass");
       passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
       passes_.push_back("operator_reshape2_onednn_fuse_pass");
@@ -473,8 +471,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
       passes_.push_back("cpu_quantize_pass");
       passes_.push_back("cpu_quantize_squash_pass");
       passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
-      passes_.push_back("int8_scale_calculation_mkldnn_pass");
-      passes_.push_back("params_quantization_mkldnn_pass");
     }
     use_mkldnn_int8_ = true;
 #else
......
@@ -81,7 +81,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
   bool is_NTC(const dnnl::memory::desc& md) {
     auto ntc_md = dnnl::memory::desc(
-        md.dims(), md.data_type(), dnnl::memory::format_tag::ntc);
+        md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::ntc);
     return md == ntc_md;
   }
......
@@ -198,7 +198,8 @@ class MultiGRUHandler {
                       : OneDNNGetDataType<T>(),
                   OneDNNMemoryFormat::ntc);
-      auto desc = std::make_shared<dnnl::gru_forward::desc>(
+      pd = std::make_shared<dnnl::gru_forward::primitive_desc>(
+          engine_,
           dnnl::prop_kind::forward_inference,
           dir,
           x_md,
@@ -207,9 +208,8 @@ class MultiGRUHandler {
           wh_md,
           b_md,
           h_md,
-          dnnl::memory::desc());
-      pd = std::make_shared<dnnl::gru_forward::primitive_desc>(
-          *desc, attrs_[2 * layer + (dir == R2L)], engine_);
+          dnnl::memory::desc(),
+          attrs_[2 * layer + (dir == R2L)]);
       PADDLE_ENFORCE_NOT_NULL(
           pd,
           platform::errors::InvalidArgument(
@@ -234,7 +234,7 @@ class MultiGRUHandler {
       std::vector<dnnl::memory::desc> src_mds{in_md, in_md};
       pd = std::make_shared<dnnl::concat::primitive_desc>(
-          axis, src_mds, engine_);
+          engine_, axis, src_mds);
       dev_ctx_.SetBlob(pd_key, pd);
     }
     concat_pds_[layer] = pd;
@@ -612,7 +612,7 @@ class MultiGRUHandler {
   bool isNTC(const dnnl::memory::desc& md) {
     auto ntc_md = dnnl::memory::desc(
-        md.dims(), md.data_type(), dnnl::memory::format_tag::ntc);
+        md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::ntc);
     return md == ntc_md;
   }
......
@@ -323,9 +323,7 @@ phi::KernelKey GetPad3dExpectedKernelType(
   // only constant mode and non-blocked layouts are supported for oneDNN
   if (op_ptr->CanMKLDNNBeUsed(ctx, input_data_type) &&
       ctx.Attr<std::string>("mode") == "constant" &&
-      ctx.Input<phi::DenseTensor>("X")
-              ->mem_desc()
-              .data.format_desc.blocking.inner_nblks == 0) {
+      ctx.Input<phi::DenseTensor>("X")->mem_desc().get_inner_nblks() == 0) {
     return phi::KernelKey(phi::Backend::ONEDNN,
                           phi::DataLayout::ONEDNN,
                           phi::TransToPhiDataType(input_data_type));
......
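Several hunks in this commit are the same mechanical substitution: reads of the public md.data fields from v2.x become query calls on dnnl::memory::desc in v3. A small hedged sketch of that querying pattern (the descriptor below is made up for illustration):

#include <cassert>
#include "oneapi/dnnl/dnnl.hpp"

// Illustrative only: oneDNN v3 replaces direct reads of md.data fields with
// query methods on dnnl::memory::desc.
void query_api_sketch() {
  auto md = dnnl::memory::desc({8, 16, 32, 32},
                               dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nchw);
  auto dims = md.get_dims();               // was md.dims()
  auto dtype = md.get_data_type();         // was md.data_type()
  auto strides = md.get_strides();         // was md.data.format_desc.blocking.strides
  int inner_nblks = md.get_inner_nblks();  // was ...blocking.inner_nblks
  assert(dims.size() == 4 && inner_nblks == 0);
  assert(dtype == dnnl::memory::data_type::f32 && strides[3] == 1);
}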
@@ -33,7 +33,37 @@ struct InnerProductCache {
   dnnl::memory weights_mem;
   dnnl::memory bias_mem;
   dnnl::memory dst_mem;
+  dnnl::memory src_scales_mem;
+  dnnl::memory wei_scales_mem;
+  dnnl::memory dst_scales_mem;
 };
+
+std::tuple<std::vector<float>,
+           std::vector<float>,
+           std::vector<float>,
+           std::vector<float>>
+GetDNNLScales(const ExecutionContext& ctx) {
+  auto scale_in_data = ctx.Attr<float>("Scale_in");
+  auto scale_out = ctx.Attr<float>("Scale_out");
+  auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
+  auto scale_in_eltwise_data = ctx.HasAttr("Scale_in_eltwise")
+                                   ? ctx.Attr<float>("Scale_in_eltwise")
+                                   : 1.0f;
+
+  std::vector<float> dnnl_src_scales = {1.f / scale_in_data};
+  size_t count = scale_weights_data.size();
+  std::vector<float> dnnl_wei_scales(count);
+#pragma omp parallel for if (count > 50)
+  for (size_t i = 0; i < count; i++) {
+    dnnl_wei_scales[i] = 1.f / scale_weights_data[i];
+  }
+  std::vector<float> dnnl_psum_scales = {1.f / scale_in_eltwise_data};
+  std::vector<float> dnnl_dst_scales = {1.f / scale_out};
+
+  return std::make_tuple(
+      dnnl_src_scales, dnnl_wei_scales, dnnl_psum_scales, dnnl_dst_scales);
+}
+
 template <typename T_in, typename T_w, typename T_out>
 class FCMKLDNNHandler
     : public phi::funcs::OneDNNHandlerNoCachingT<T_in,
@@ -100,11 +130,46 @@ class FCMKLDNNHandler
     float sum_scale = 1.0f;
     float activation_scale = 1.0f;
     if (phi::funcs::is_int8<T_w>()) {
-      std::vector<float> output_shift_scale;
-      std::tie(output_shift_scale, sum_scale, activation_scale) =
-          GetOutputScales(ctx);
-      int mask = CreateMask(1, output_shift_scale.size() > 1);
-      attributes.set_output_scales(mask, output_shift_scale);
+      std::vector<float> src_scales, wei_scales, psum_scales, dst_scales;
+      std::tie(src_scales, wei_scales, psum_scales, dst_scales) =
+          GetDNNLScales(ctx);
+
+      bool force_fp32_output = ctx.HasAttr("force_fp32_output") &&
+                               ctx.Attr<bool>("force_fp32_output");
+
+      attributes.set_scales_mask(DNNL_ARG_SRC, 0);
+      dnnl::memory::desc src_scales_md(
+          {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+      src_scales_mem_ = dnnl::memory(src_scales_md, this->engine_);
+      memcpy(src_scales_mem_.get_data_handle(),
+             src_scales.data(),
+             src_scales.size() * sizeof(float));
+
+      int mask = wei_scales.size() > 1 ? 1 : 0;
+      attributes.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
+      dnnl::memory::desc wei_scales_md(
+          {static_cast<int64_t>(wei_scales.size())},
+          dnnl::memory::data_type::f32,
+          dnnl::memory::format_tag::x);
+      wei_scales_mem_ = dnnl::memory(wei_scales_md, this->engine_);
+      memcpy(wei_scales_mem_.get_data_handle(),
+             wei_scales.data(),
+             wei_scales.size() * sizeof(float));
+
+      if (!force_fp32_output) {
+        attributes.set_scales_mask(DNNL_ARG_DST, 0);
+        dnnl::memory::desc dst_scales_md(
+            {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+        dst_scales_mem_ = dnnl::memory(dst_scales_md, this->engine_);
+        memcpy(dst_scales_mem_.get_data_handle(),
+               dst_scales.data(),
+               dst_scales.size() * sizeof(float));
+      }
+
+      sum_scale = psum_scales[0];
     }
     if (ctx.HasAttr("fuse_residual_connection") &&
@@ -114,41 +179,20 @@ class FCMKLDNNHandler
     // ReLU from "fc_fuse_pass"
     if (ctx.Attr<std::string>("activation_type") == "relu") {
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
+      post_operations.append_eltwise(dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
     }
     AppendActivation(ctx, post_operations, activation_scale);
     if (ctx.HasAttr("fused_output_scale")) {
       float scale_alpha = ctx.Attr<float>("fused_output_scale");
       post_operations.append_eltwise(
-          1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
+          dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
     }
     attributes.set_post_ops(post_operations);
     return attributes;
   }
-  // Compute the bias scales so that its values correspond to the
-  // scale of data being an output of weights and input multiplication
-  std::vector<float> GetBiasScales(const ExecutionContext& ctx) {
-    if (ctx.HasAttr("Bias_scales")) {
-      return ctx.Attr<std::vector<float>>("Bias_scales");
-    } else {
-      const float scale_in = ctx.Attr<float>("Scale_in");
-      const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
-      std::vector<float> bias_scales(scale_weights.size());
-
-      for (size_t i = 0; i < bias_scales.size(); ++i) {
-        if (scale_weights[i] == 0.0)
-          bias_scales[i] = 1.0f;
-        else
-          bias_scales[i] = scale_in * scale_weights[i];
-      }
-      return bias_scales;
-    }
-  }
   void AppendActivation(const ExecutionContext& ctx,
                         dnnl::post_ops& post_ops,  // NOLINT
                         float activation_scale = 1.0f) {
@@ -174,55 +218,9 @@ class FCMKLDNNHandler
             "Activation '%s' not found in oneDNN algorithms mapper",
             fuse_activation));
-    post_ops.append_eltwise(
-        activation_scale, activation_type->second, fuse_alpha, fuse_beta);
+    post_ops.append_eltwise(activation_type->second, fuse_alpha, fuse_beta);
+    post_ops.append_eltwise(
+        dnnl::algorithm::eltwise_linear, activation_scale, 0.0f);
   }
-  // Correct output scale, to take into account scaling of input and weights
-  // Since the data that comes out of input and weight multiplication is
-  // scaled with its own scales, this data needs to be divided by
-  // those scales to normalise them back to what their floating-point range
-  // was. Then we multiply them by desired output scale we want on the output.
-  std::tuple<std::vector<float>, float, float> GetOutputScales(
-      const ExecutionContext& ctx) {
-    if (ctx.HasAttr("Sum_scale")) {
-      return std::make_tuple(ctx.Attr<std::vector<float>>("Output_shift_scale"),
-                             ctx.Attr<float>("Sum_scale"),
-                             ctx.Attr<float>("Activation_scale"));
-    } else {
-      auto scale_in_data = ctx.Attr<float>("Scale_in");
-      auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
-      bool has_activation = !ctx.Attr<std::string>("activation_type").empty() ||
-                            (ctx.HasAttr("fuse_activation") &&
-                             !ctx.Attr<std::string>("fuse_activation").empty());
-      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-      bool fuse_residual_conn = ctx.HasAttr("fuse_residual_connection") &&
-                                ctx.Attr<bool>("fuse_residual_connection");
-      auto scale_in_eltwise_data = ctx.HasAttr("Scale_in_eltwise")
-                                       ? ctx.Attr<float>("Scale_in_eltwise")
-                                       : 1.0f;
-      // If the output will be in floats, we don't multiply by scale_out.
-      float activation_scale = (!force_fp32_output && has_activation)
-                                   ? ctx.Attr<float>("Scale_out")
-                                   : 1.0f;
-      float scale_out_data = (force_fp32_output || has_activation)
-                                 ? 1.0f
-                                 : ctx.Attr<float>("Scale_out");
-      float sum_scale =
-          fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
-      const size_t weight_scales_num = scale_weights_data.size();
-
-      for (size_t i = 0; i < weight_scales_num; ++i) {
-        if (scale_weights_data[i] == 0.0)
-          scale_weights_data[i] = scale_out_data;
-        else
-          scale_weights_data[i] =
-              scale_out_data / (scale_in_data * scale_weights_data[i]);
-      }
-      return std::make_tuple(scale_weights_data, sum_scale, activation_scale);
-    }
-  }
   // Computing oneDNN's scaling mask which determines along which dimension
@@ -235,7 +233,8 @@ class FCMKLDNNHandler
       const dnnl::memory::desc& user_md,
       const dnnl::memory::desc& target_md,
       void* ptr,
-      const dnnl::primitive_attr& attrs) {
+      const dnnl::primitive_attr& attrs,
+      const std::vector<float>& scale_data) {
     std::shared_ptr<dnnl::memory> target_memory_p;
     auto user_memory_p =
@@ -244,16 +243,21 @@ class FCMKLDNNHandler
     auto reorder_p = std::make_shared<dnnl::reorder>(
         *user_memory_p, *target_memory_p, attrs);
+    auto scales_md =
+        dnnl::memory::desc({static_cast<int64_t>(scale_data.size())},
+                           dnnl::memory::data_type::f32,
+                           dnnl::memory::format_tag::x);
+    auto scale_mem =
+        dnnl::memory(scales_md,
+                     this->engine_,
+                     phi::funcs::to_void_cast<float>(scale_data.data()));
+
     auto& astream = OneDNNContext::tls().get_stream();
     {
-      platform::RecordEvent record_reorder(
-          "int_reorder",
-          platform::TracerEventType::UserDefined,
-          1,
-          platform::EventRole::kUniqueOp);
-      reorder_p->execute(
-          astream,
-          {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
+      reorder_p->execute(astream,
+                         {{DNNL_ARG_FROM, *user_memory_p},
+                          {DNNL_ARG_TO, *target_memory_p},
+                          {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}});
       astream.wait();
     }
@@ -262,6 +266,9 @@ class FCMKLDNNHandler
   std::string memory_key_;
   const OneDNNContext& dev_ctx_;
+  dnnl::memory src_scales_mem_;
+  dnnl::memory wei_scales_mem_;
+  dnnl::memory dst_scales_mem_;
+
  public:
   std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
@@ -272,7 +279,7 @@ class FCMKLDNNHandler
     if (x->dims().size() != 2) {
       // reshape restrictions are always satisfied because in case of 3 or 4 dim
       // input, plain layout is enforced
-      user_md = user_md.reshape(this->fwd_pd_->src_desc().dims());
+      user_md = user_md.reshape(this->fwd_pd_->src_desc().get_dims());
     }
     return this->AcquireMemoryWithReorder(
@@ -282,36 +289,8 @@ class FCMKLDNNHandler
   std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
       const ExecutionContext& ctx, const phi::DenseTensor* bias) {
     const float* bias_data = bias->data<float>();
-
-    if (phi::funcs::is_int8<T_w>() == false) {
-      // for BF16/FP32 bias is 1D and has no scales, so reorder is not needed
-      return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
-                                              to_void_cast<float>(bias_data));
-    } else {
-      const std::string bias_key = this->memory_key_ + "@bias";
-      auto memory_p = std::static_pointer_cast<dnnl::memory>(
-          this->dev_ctx_.GetBlob(bias_key));
-
-      if (!memory_p) {
-        const auto& scale_data = GetBiasScales(ctx);
-        dnnl::primitive_attr attrs;
-
-        int mask = CreateMask(0, scale_data.size() > 1);
-        attrs.set_output_scales(mask, scale_data);
-
-        auto user_md = dnnl::memory::desc({bias->dims()[0]},
-                                          OneDNNGetDataType<float>(),
-                                          dnnl::memory::format_tag::a);
-
-        memory_p = this->AcquireMemoryWithReorderAndAttrs(
-            user_md,
-            this->fwd_pd_->bias_desc(),
-            to_void_cast<float>(bias_data),
-            attrs);
-        this->dev_ctx_.SetBlob(bias_key, memory_p);
-      }
-      return memory_p;
-    }
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
+                                            to_void_cast<float>(bias_data));
   }
   std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
@@ -322,7 +301,7 @@ class FCMKLDNNHandler
     if (!memory_p) {
       const float* weights_data = weights->data<float>();
-      auto weights_dims = this->fwd_pd_->weights_desc().dims();
+      auto weights_dims = this->fwd_pd_->weights_desc().get_dims();
       auto user_md = dnnl::memory::desc(weights_dims,
                                         OneDNNGetDataType<float>(),
@@ -331,13 +310,14 @@ class FCMKLDNNHandler
       if (phi::funcs::is_int8<T_w>()) {
         dnnl::primitive_attr attrs;
         int mask = CreateMask(0, scale_data.size() > 1);
-        attrs.set_output_scales(mask, scale_data);
+        attrs.set_scales_mask(DNNL_ARG_SRC, mask);
         memory_p = this->AcquireMemoryWithReorderAndAttrs(
             user_md,
             this->fwd_pd_->weights_desc(),
            to_void_cast<float>(weights_data),
-            attrs);
+            attrs,
+            scale_data);
       } else {
         memory_p =
             this->AcquireMemoryWithReorder(user_md,
@@ -370,7 +350,18 @@ class FCMKLDNNHandler
     }
     return this->template AcquireDstMemory<T_out>(out);
   }  // namespace operators
+
+  void SetScalesIfNeeded(std::unordered_map<int, dnnl::memory>* args) {
+    if (src_scales_mem_.get_desc().is_zero() != true) {
+      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_mem_});
+      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_mem_});
+    }
+    // dst scales may be empty when force fp32 output
+    if (dst_scales_mem_.get(true)) {
+      args->insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_mem_});
+    }
+  }
 };  // namespace paddle
 #define IF_CHANGE_FC_TW_TYPENAME(condition, ...) \
   if (condition) {                               \
@@ -408,7 +399,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
       const std::shared_ptr<dnnl::memory>& src_mem,
       const phi::DenseTensor* x,
       const dnnl::engine& engine) const {
-    auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims());
+    auto x_md = x->mem_desc().reshape(src_mem->get_desc().get_dims());
     if (x_md != src_mem->get_desc()) {
       dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
       auto reorder_p = dnnl::reorder(x_mem, *src_mem);
@@ -453,6 +444,8 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
     RecomputeOutputDims(ctx, x, weights, out);
+    std::unordered_map<int, dnnl::memory> fc_args;
+
     if (inner_product_cache) {
       fc_p = std::make_shared<dnnl::inner_product_forward>(
          inner_product_cache->inner_product_p);
@@ -474,9 +467,25 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
           ctx.GetPlace(), dst_memory_p->get_desc().get_size());
       dst_memory_p->set_data_handle(out_ptr);
+      fc_args.insert({DNNL_ARG_SRC, *src_memory_p});
+      fc_args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
+      fc_args.insert({DNNL_ARG_DST, *dst_memory_p});
+
       if (bias) {
         bias_memory_p =
             std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
+        fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
+      }
+
+      if (inner_product_cache->src_scales_mem.get(true)) {
+        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC,
+                        inner_product_cache->src_scales_mem});
+        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
+                        inner_product_cache->wei_scales_mem});
+      }
+      if (inner_product_cache->dst_scales_mem.get(true)) {
+        fc_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
+                        inner_product_cache->dst_scales_mem});
       }
     } else {
       auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
@@ -495,25 +504,23 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
       weights_memory_p =
          handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
       dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
+      fc_args.insert({DNNL_ARG_SRC, *src_memory_p});
+      fc_args.insert({DNNL_ARG_WEIGHTS, *weights_memory_p});
+      fc_args.insert({DNNL_ARG_DST, *dst_memory_p});
+
       if (bias) {
         bias_memory_p = handler.AcquireBiasMemoryWithReorder(ctx, bias);
+        fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
+      }
+
+      if (phi::funcs::is_int8<T_in>()) {
+        handler.SetScalesIfNeeded(&fc_args);
       }
       fc_p = handler.AcquireForwardPrimitive();
     }
     auto& astream = OneDNNContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> fc_args = {
-        {DNNL_ARG_SRC, *src_memory_p},
-        {DNNL_ARG_WEIGHTS, *weights_memory_p},
-        {DNNL_ARG_DST, *dst_memory_p}};
-
-    if (bias) {
-      fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-    }
     fc_p->execute(astream, fc_args);
     astream.wait();
@@ -526,6 +533,18 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
       if (bias) {
         ip_cache->bias_mem = *bias_memory_p;
       }
+
+      if (fc_args.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)) {
+        ip_cache->src_scales_mem =
+            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
+        ip_cache->wei_scales_mem =
+            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
+      }
+      if (fc_args.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)) {
+        ip_cache->dst_scales_mem =
+            fc_args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
+      }
+
       dev_ctx.SetBlob(cache_key, ip_cache);
     }
......
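The recurring pattern in the FC kernel above — set_scales_mask on the primitive attributes at build time, then the actual scale values passed as f32 memories at execute time — is the oneDNN v3 replacement for v2's set_output_scales. A hedged, self-contained sketch of that pattern on a plain reorder follows; buffer sizes and values are made up for illustration and this is not code from the patch.

#include <cstdint>
#include <unordered_map>
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// Illustrative only: declare via set_scales_mask that an argument is scaled,
// then hand the value over at execution time as DNNL_ARG_ATTR_SCALES | arg,
// here for an f32 -> s8 reorder.
void runtime_scale_sketch() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_data(16, 0.5f);  // made-up values
  std::vector<int8_t> dst_data(16);
  auto src_md = dnnl::memory::desc(
      {16}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  auto dst_md = dnnl::memory::desc(
      {16}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::x);
  auto src_mem = dnnl::memory(src_md, eng, src_data.data());
  auto dst_mem = dnnl::memory(dst_md, eng, dst_data.data());

  dnnl::primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_SRC, 0);  // one common scale (mask = 0)

  float scale = 127.f;  // runtime value, not baked into the primitive
  auto scale_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  auto scale_mem = dnnl::memory(scale_md, eng, &scale);

  auto reorder_pd =
      dnnl::reorder::primitive_desc(eng, src_md, eng, dst_md, attr);
  std::unordered_map<int, dnnl::memory> args{
      {DNNL_ARG_SRC, src_mem},
      {DNNL_ARG_DST, dst_mem},
      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}};
  dnnl::reorder(reorder_pd).execute(strm, args);
  strm.wait();
}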
@@ -37,49 +37,34 @@ class LayerNormOneDNNHandler
             engine, cpu_place) {
     const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
                                        : dnnl::prop_kind::forward_training;
     this->AcquireForwardPrimitiveDescriptor(
-        fwd_prop_kind, x->mem_desc(), epsilon, flags);
+        fwd_prop_kind, x->mem_desc(), x->mem_desc(), epsilon, flags);
   }
-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
-      const phi::DenseTensor* scale,
-      const phi::DenseTensor* shift,
-      const framework::ExecutionContext& ctx) {
-    // OneDNN requires a single piece of memory for scale and shift data. During
-    // inference both pieces of memory are merged inside
-    // layer_norm_onednn_optimization_pass, but during training we have to
-    // manually copy them into new memory buffer
-    auto* scaleshift = ctx.Input<phi::DenseTensor>("ScaleShift");
-    if (scaleshift) {
-      return this->AcquireMemoryFromPrimitive(
-          this->fwd_pd_->weights_desc(),
-          phi::funcs::to_void_cast(scaleshift->data<float>()));
-    } else {
-      const unsigned int C = phi::vectorize(scale->dims())[0];
-      auto scaleshift_memory =
-          this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
-      auto mem_data_handle =
-          reinterpret_cast<float*>(scaleshift_memory->get_data_handle());
-      std::copy(
-          scale->data<float>(), scale->data<float>() + C, mem_data_handle);
-      std::copy(
-          shift->data<float>(), shift->data<float>() + C, mem_data_handle + C);
-      return scaleshift_memory;
-    }
+  std::tuple<std::shared_ptr<dnnl::memory>, std::shared_ptr<dnnl::memory>>
+  AcquireScaleShiftMemory(const phi::DenseTensor* scale,
+                          const phi::DenseTensor* shift) {
+    auto scale_memory = this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->weights_desc(),
+        phi::funcs::to_void_cast<float>(scale->data<float>()));
+    auto shift_memory = this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->weights_desc(),
+        phi::funcs::to_void_cast<float>(shift->data<float>()));
+
+    return std::make_tuple(scale_memory, shift_memory);
   }
   std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
-    T* mean_data = mean->mutable_data<T>(this->place_,
-                                         this->fwd_pd_->mean_desc().get_size());
+    float* mean_data = mean->mutable_data<float>(
+        this->place_, this->fwd_pd_->mean_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
                                             mean_data);
   }
   std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
       phi::DenseTensor* variance) {
-    T* variance_data = variance->mutable_data<T>(
+    float* variance_data = variance->mutable_data<float>(
         this->place_, this->fwd_pd_->variance_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
                                             variance_data);
@@ -114,7 +99,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     dnnl::normalization_flags flags{};
     if (with_scaleshift) {
-      flags |= dnnl::normalization_flags::use_scale_shift;
+      flags |= dnnl::normalization_flags::use_scale |
+               dnnl::normalization_flags::use_shift;
     }
     LayerNormOneDNNHandler<T> handler(
@@ -141,9 +127,9 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
     if (with_scaleshift) {
-      std::shared_ptr<dnnl::memory> scaleshift_memory =
-          handler.AcquireScaleShiftMemory(scale, bias, ctx);
-      args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
+      auto scaleshift_mems = handler.AcquireScaleShiftMemory(scale, bias);
+      args.insert({DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))});
+      args.insert({DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))});
    }
     layer_norm_p->execute(astream, args);
......
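The reason layer_norm_onednn_optimization_pass (and its merged ScaleShift tensor) disappears in this commit is visible in the hunk above: oneDNN v3 layer normalization takes scale and shift as two separate 1-D arguments. A hedged sketch of that v3 calling convention follows; shapes and values are illustrative and this is not code from the patch.

#include <unordered_map>
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// Illustrative only: v3 layer_norm takes separate SCALE and SHIFT memories
// (use_scale | use_shift), so packing both into one 2*C buffer is no longer
// required.
void layer_norm_v3_sketch() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  const int N = 2, C = 4;  // made-up shape
  std::vector<float> x(N * C, 1.f), y(N * C), gamma(C, 1.f), beta(C, 0.f);

  auto data_md = dnnl::memory::desc(
      {N, C}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
  auto pd = dnnl::layer_normalization_forward::primitive_desc(
      eng,
      dnnl::prop_kind::forward_inference,
      data_md,
      data_md,
      1e-5f,
      dnnl::normalization_flags::use_scale |
          dnnl::normalization_flags::use_shift);

  std::unordered_map<int, dnnl::memory> args{
      {DNNL_ARG_SRC, dnnl::memory(data_md, eng, x.data())},
      {DNNL_ARG_DST, dnnl::memory(data_md, eng, y.data())},
      {DNNL_ARG_SCALE, dnnl::memory(pd.weights_desc(), eng, gamma.data())},
      {DNNL_ARG_SHIFT, dnnl::memory(pd.weights_desc(), eng, beta.data())}};
  dnnl::layer_normalization_forward(pd).execute(strm, args);
  strm.wait();
}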
@@ -50,6 +50,7 @@ class LRNOneDNNHandler
             : dnnl::prop_kind::forward_training,
         dnnl::algorithm::lrn_across_channels,
         input->mem_desc(),
+        input->mem_desc(),
         n,
         alpha,
         beta,
@@ -80,6 +81,7 @@ class LRNOneDNNHandler
         dnnl::prop_kind::forward_training,
         dnnl::algorithm::lrn_across_channels,
         in_x->mem_desc(),
+        in_x->mem_desc(),
         n,
         alpha,
         beta,
@@ -87,8 +89,9 @@ class LRNOneDNNHandler
     this->AcquireBackwardPrimitiveDescriptor(
         dnnl::algorithm::lrn_across_channels,
-        in_x->mem_desc(),
         out_grad->mem_desc(),
+        out_grad->mem_desc(),
+        in_x->mem_desc(),
         n,
         alpha,
         beta,
......
@@ -160,7 +160,7 @@ class MatMulV1OneDNNHandler
     dnnl::primitive_attr matmul_attrs;
     float scale_out = ComputeOutputScale(ctx);
     if (scale_out != 1.0f) {
-      matmul_attrs.set_output_scales(0, {scale_out});
+      matmul_attrs.set_scales_mask(DNNL_ARG_SRC, 0);
     }
     return matmul_attrs;
   }
@@ -226,7 +226,9 @@ class MatMulOneDNNHandler
     auto out_md = memory::desc(out_dims, OneDNNGetDataType<OT>(), out_strides);
     dnnl::primitive_attr attrs;
-    if (scale != 1.0f) attrs.set_output_scales(0, {scale});
+    if (scale != 1.0f) {
+      attrs.set_scales_mask(DNNL_ARG_SRC, 0);
+    }
     this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }
@@ -330,6 +332,15 @@ void ExecuteMatMulV1(const ExecutionContext &ctx,
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
+
+  float computed_scale_x = handler.ComputeOutputScale(ctx);
+  if (std::fabs(computed_scale_x - 1.f) > 1e-6f) {
+    auto scale_x_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto scale_x_mem =
+        dnnl::memory(scale_x_md, onednn_engine, &computed_scale_x);
+    matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_x_mem});
+  }
+
   auto &astream = OneDNNContext::tls().get_stream();
   matmul_p->execute(astream, matmul_args);
   astream.wait();
@@ -602,6 +613,16 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
       y_combined = *y;
     }
+
+    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
+    auto alpha_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto scale_mem =
+        alpha != 1.0f
+            ? dnnl::memory(
+                  alpha_md, engine, phi::funcs::to_void_cast<float>(&alpha))
+            : dnnl::memory();
+
     MatMulOneDNNHandler<T, T, T> handler(engine,
                                          ctx.GetPlace(),
                                          &x_combined,
@@ -621,6 +642,9 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
         {DNNL_ARG_SRC, *src_memory_p},
         {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};
+    if (alpha != 1.0f) {
+      matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem});
+    }
     auto &astream = OneDNNContext::tls().get_stream();
     matmul_p->execute(astream, matmul_args);
......
@@ -36,7 +36,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<phi::DenseTensor>("Output");
     const auto quantization_scale = ctx.Attr<float>("Scale");
-    const auto quantization_shift = ctx.Attr<float>("Shift");
+    const auto quantization_shift =
+        static_cast<int32_t>(ctx.Attr<float>("Shift"));
     const bool with_scale = quantization_scale != 1.0f;
     const bool with_shift = quantization_shift != 0.0f;
@@ -61,12 +62,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
     static constexpr int32_t mask = 0;
     if (with_scale) {
-      attrs.set_output_scales(mask, {quantization_scale});
+      attrs.set_scales_mask(DNNL_ARG_SRC, mask);
     }
     if (with_shift) {
-      attrs.set_zero_points(
-          DNNL_ARG_DST, mask, {static_cast<int32_t>(quantization_shift)});
+      attrs.set_zero_points_mask(DNNL_ARG_DST, mask);
     }
     auto x_type = phi::funcs::ToOneDNNDataType(x->dtype());
@@ -94,7 +94,32 @@ class QuantOpKernel : public framework::OpKernel<T> {
         reorder_dst_memory_p, reorder_src_memory_p, attrs);
     auto& astream = phi::OneDNNContext::tls().get_stream();
-    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
+
+    auto scales_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto scales_mem =
+        dnnl::memory(scales_md,
+                     dev_ctx.GetEngine(),
+                     phi::funcs::to_void_cast<float>(&quantization_scale));
+
+    auto zero_points_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
+    auto zero_points_mem =
+        dnnl::memory(zero_points_md,
+                     dev_ctx.GetEngine(),
+                     phi::funcs::to_void_cast<int32_t>(&quantization_shift));
+
+    std::unordered_map<int, dnnl::memory> reorder_args;
+    reorder_args.insert({DNNL_ARG_SRC, *reorder_src_memory_p});
+    reorder_args.insert({DNNL_ARG_DST, *reorder_dst_memory_p});
+    if (with_scale) {
+      reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_mem});
+    }
+    if (with_shift) {
+      reorder_args.insert(
+          {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_mem});
+    }
+
+    reorder_p->execute(astream, reorder_args);
     astream.wait();
     out->set_mem_desc(reorder_dst_memory_p->get_desc());
......
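As with scales, v3 zero points are only declared through a mask on the attributes, and the actual value travels as an s32 memory at execute time. A hedged sketch of that pattern on a quantizing reorder with a destination shift follows; buffer contents are illustrative and this is not code from the patch.

#include <cstdint>
#include <unordered_map>
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// Illustrative only: f32 -> u8 reorder following the set_zero_points_mask +
// DNNL_ARG_ATTR_ZERO_POINTS pattern used by the quantize kernel above.
void zero_point_sketch() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_data(8, 0.25f);  // made-up values
  std::vector<uint8_t> dst_data(8);
  auto src_md = dnnl::memory::desc(
      {8}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  auto dst_md = dnnl::memory::desc(
      {8}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::x);

  dnnl::primitive_attr attr;
  attr.set_zero_points_mask(DNNL_ARG_DST, 0);  // common zero point for DST

  int32_t zero_point = 128;  // runtime value
  auto zp_md = dnnl::memory::desc(
      {1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
  auto zp_mem = dnnl::memory(zp_md, eng, &zero_point);

  auto reorder_pd =
      dnnl::reorder::primitive_desc(eng, src_md, eng, dst_md, attr);
  std::unordered_map<int, dnnl::memory> args{
      {DNNL_ARG_SRC, dnnl::memory(src_md, eng, src_data.data())},
      {DNNL_ARG_DST, dnnl::memory(dst_md, eng, dst_data.data())},
      {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_mem}};
  dnnl::reorder(reorder_pd).execute(strm, args);
  strm.wait();
}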
@@ -39,10 +39,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<phi::DenseTensor>("Input");
     auto scale_in = ctx.Attr<float>("Scale_in");
-    auto shift_in = ctx.Attr<float>("Shift_in");
+    auto shift_in = static_cast<int32_t>(ctx.Attr<float>("Shift_in"));
     auto scale_out = ctx.Attr<float>("Scale_out");
-    auto shift_out = ctx.Attr<float>("Shift_out");
-    bool with_shift = shift_in != 0.0f || shift_out != 0.0f;
+    auto shift_out = static_cast<int32_t>(ctx.Attr<float>("Shift_out"));
+    bool with_shift = shift_in != 0 || shift_out != 0;
     auto* output = ctx.Output<phi::DenseTensor>("Output");
     PADDLE_ENFORCE_NE(
@@ -53,7 +53,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
         scale_out,
         0.0f,
         platform::errors::InvalidArgument("Scale of output cannot be 0.0"));
-    if (shift_in != 0.0f) {
+    if (shift_in != 0) {
       PADDLE_ENFORCE_EQ(
           input->dtype(),
           DataType::UINT8,
@@ -68,19 +68,26 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     auto src_paddle_dt = input->dtype();
     auto dst_paddle_dt = with_shift ? DataType::UINT8 : src_paddle_dt;
-    auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
-    std::vector<dnnl_dim_t> vstrides(xstrides,
-                                     xstrides + input->mem_desc().data.ndims);
+    auto xstrides = input->mem_desc().get_strides();
     dnnl::primitive_attr attrs;
     int mask = 0;
-    float reorder_scale = scale_out / scale_in;
-    attrs.set_output_scales(mask, {reorder_scale});
+    float reorder_scale = scale_in / scale_out;
+    attrs.set_scales_mask(DNNL_ARG_DST, mask);
+    auto scales_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto scales_mem =
+        dnnl::memory(scales_md,
+                     dev_ctx.GetEngine(),
+                     phi::funcs::to_void_cast<float>(&reorder_scale));
+
+    uint32_t reorder_shift =
+        with_shift
+            ? clip_to_uint8(shift_out - (1.0f / reorder_scale) * shift_in)
+            : 0;
+
     if (with_shift) {
-      uint8_t reorder_shift =
-          clip_to_uint8(shift_out - reorder_scale * shift_in);
-      attrs.set_zero_points(
-          DNNL_ARG_DST, mask, {static_cast<int32_t>(reorder_shift)});
+      attrs.set_zero_points_mask(DNNL_ARG_DST, mask);
     }
     phi::funcs::ReorderOneDNNHandler reorder_handler(
@@ -94,13 +101,29 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     auto src_memory_p = reorder_handler.AcquireSrcMemory(
         input->mem_desc(), phi::funcs::to_void_cast(input->data<T>()));
     auto dst_memory_p = reorder_handler.AcquireDstMemory(
-        output, src_tz, vstrides, dev_ctx.GetPlace());
+        output, src_tz, xstrides, dev_ctx.GetPlace());
     auto reorder_p =
         reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs);
     auto& astream = phi::OneDNNContext::tls().get_stream();
-    reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
+
+    auto zero_points_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
+    auto zero_points_out_mem =
+        dnnl::memory(zero_points_md, dev_ctx.GetEngine(), &reorder_shift);
+
+    std::unordered_map<int, dnnl::memory> reorder_args;
+    reorder_args.insert({DNNL_ARG_SRC, *src_memory_p});
+    reorder_args.insert({DNNL_ARG_DST, *dst_memory_p});
+    reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem});
+    // shift for DST
+    if (with_shift) {
+      reorder_args.insert(
+          {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zero_points_out_mem});
+    }
+
+    reorder_p->execute(astream, reorder_args);
     astream.wait();
     output->set_mem_desc(dst_memory_p->get_desc());
......
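A minimal standalone sketch (not part of the diff) of the reorder-scale pattern that the quantize and requantize kernels above now follow: oneDNN v3 drops primitive_attr::set_output_scales, so the attribute only records a scales mask and the actual value is bound at execution time as a one-element f32 memory. Engine, stream and memory names below are illustrative, not taken from the Paddle code.

#include "dnnl.hpp"

// Hedged sketch: reorder src -> dst with a runtime destination scale (oneDNN v3 style).
void reorder_with_dst_scale(const dnnl::engine &eng,
                            dnnl::stream &strm,
                            dnnl::memory &src_mem,
                            dnnl::memory &dst_mem,
                            float scale) {
  dnnl::primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_DST, 0);  // v2 equivalent: attr.set_output_scales(0, {scale});

  // Scales are no longer baked into the attribute; they travel as a 1-element f32 memory.
  dnnl::memory::desc scales_md({1}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::x);
  dnnl::memory scales_mem(scales_md, eng, &scale);

  dnnl::reorder reorder_p(src_mem, dst_mem, attr);
  reorder_p.execute(strm, {{DNNL_ARG_SRC, src_mem},
                           {DNNL_ARG_DST, dst_mem},
                           {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem}});
  strm.wait();
}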
...@@ -29,8 +29,11 @@ class ShuffleChannelMKLDNNHandler ...@@ -29,8 +29,11 @@ class ShuffleChannelMKLDNNHandler
: phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::shuffle_forward>( : phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(
engine, cpu_place) { engine, cpu_place) {
static constexpr int channel_axis = 1; static constexpr int channel_axis = 1;
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_training, x->mem_desc(), channel_axis, group); x->mem_desc(),
x->mem_desc(),
channel_axis,
group);
} }
}; };
......
...@@ -55,13 +55,13 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -55,13 +55,13 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = auto dst_md =
dnnl::memory::desc(x_vec_dims, dnnl::memory::desc(x_vec_dims,
x->mem_desc().data_type(), x->mem_desc().get_data_type(),
phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size())); phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size()));
auto dst_strides = auto dst_strides =
phi::funcs::FakeTransposeStrides(dst_md.dims(), transpose_axis); phi::funcs::FakeTransposeStrides(dst_md.get_dims(), transpose_axis);
dst_md = dst_md = dnnl::memory::desc(
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides); x_vec_dims, x->mem_desc().get_data_type(), dst_strides);
auto dst_data = auto dst_data =
out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size()); out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
......
...@@ -703,9 +703,7 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -703,9 +703,7 @@ class Pad2dOp : public framework::OperatorWithKernel {
// only constant mode and non-blocked layouts are supported for oneDNN // only constant mode and non-blocked layouts are supported for oneDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type) && if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
ctx.Attr<std::string>("mode") == "constant" && ctx.Attr<std::string>("mode") == "constant" &&
ctx.Input<phi::DenseTensor>("X") ctx.Input<phi::DenseTensor>("X")->mem_desc().get_inner_nblks() == 0) {
->mem_desc()
.data.format_desc.blocking.inner_nblks == 0) {
return phi::KernelKey(phi::Backend::ONEDNN, return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN, phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type)); phi::TransToPhiDataType(input_data_type));
......
...@@ -164,7 +164,7 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -164,7 +164,7 @@ class SliceOp : public framework::OperatorWithKernel {
// created, so in that scenario a fallback is needed // created, so in that scenario a fallback is needed
if (ctx.Input<phi::DenseTensor>("Input") if (ctx.Input<phi::DenseTensor>("Input")
->mem_desc() ->mem_desc()
.data.format_desc.blocking.inner_nblks == 0) { .get_inner_nblks() == 0) {
return phi::KernelKey(phi::Backend::ONEDNN, return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN, phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type)); phi::TransToPhiDataType(input_data_type));
...@@ -341,7 +341,7 @@ class SliceOpGrad : public framework::OperatorWithKernel { ...@@ -341,7 +341,7 @@ class SliceOpGrad : public framework::OperatorWithKernel {
// created, so in that scenario a fallback is needed // created, so in that scenario a fallback is needed
if (ctx.Input<phi::DenseTensor>(framework::GradVarName("Out")) if (ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))
->mem_desc() ->mem_desc()
.data.format_desc.blocking.inner_nblks == 0) { .get_inner_nblks() == 0) {
return phi::KernelKey(phi::Backend::ONEDNN, return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN, phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type)); phi::TransToPhiDataType(input_data_type));
......
...@@ -124,7 +124,7 @@ class SplitOp : public framework::OperatorWithKernel { ...@@ -124,7 +124,7 @@ class SplitOp : public framework::OperatorWithKernel {
// 16(depending on which blocking format is used) submemory cannot be // 16(depending on which blocking format is used) submemory cannot be
// created, so in that scenario a fallback is needed // created, so in that scenario a fallback is needed
const auto x_md = ctx.Input<phi::DenseTensor>("X")->mem_desc(); const auto x_md = ctx.Input<phi::DenseTensor>("X")->mem_desc();
if (x_md.data.format_desc.blocking.inner_nblks == 0) { if (x_md.get_inner_nblks() == 0) {
return phi::KernelKey(phi::Backend::ONEDNN, return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN, phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type)); phi::TransToPhiDataType(input_data_type));
......
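The pad2d, slice and split fallbacks above all hinge on the same accessor rename; a minimal sketch with an illustrative helper name:

#include "dnnl.hpp"

// Hedged sketch: v3 replaces md.data.format_desc.blocking.inner_nblks with a getter.
bool is_plain_layout(const dnnl::memory::desc &md) {
  return md.get_inner_nblks() == 0;  // 0 inner blocks -> no blocked (e.g. nChw16c-like) layout
}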
...@@ -33,14 +33,19 @@ class AXPYHandler { ...@@ -33,14 +33,19 @@ class AXPYHandler {
{n}, OneDNNGetDataType<T>(), dnnl::memory::format_tag::x); {n}, OneDNNGetDataType<T>(), dnnl::memory::format_tag::x);
src_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE); src_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
dst_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE); dst_mem_ = dnnl::memory(md, onednn_engine, DNNL_MEMORY_NONE);
dnnl::primitive_attr reorder_attr; dnnl::primitive_attr reorder_attr;
dnnl::post_ops post_operations;
if (alpha != 1.f) { if (alpha != 1.f) {
std::vector<float> scales(1, alpha); reorder_attr.set_scales_mask(DNNL_ARG_FROM, 0); // Ax + b
reorder_attr.set_output_scales(0, scales); auto scales_md = dnnl::memory::desc(
{n}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
src_scales_mem_ = dnnl::memory(scales_md, onednn_engine);
*reinterpret_cast<float *>(src_scales_mem_.get_data_handle()) = alpha;
} }
post_operations.append_sum(1.0f);
dnnl::post_ops post_operations;
post_operations.append_sum(1.0f);
reorder_attr.set_post_ops(post_operations); reorder_attr.set_post_ops(post_operations);
reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr); reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr);
} }
...@@ -50,6 +55,8 @@ class AXPYHandler { ...@@ -50,6 +55,8 @@ class AXPYHandler {
return src_mem_; return src_mem_;
} }
dnnl::memory &AcquireAlphaMemory() { return this->src_scales_mem_; }
dnnl::memory &AcquireDstMemory(T *y) { dnnl::memory &AcquireDstMemory(T *y) {
dst_mem_.set_data_handle(y); dst_mem_.set_data_handle(y);
return dst_mem_; return dst_mem_;
...@@ -59,6 +66,7 @@ class AXPYHandler { ...@@ -59,6 +66,7 @@ class AXPYHandler {
private: private:
dnnl::memory src_mem_; dnnl::memory src_mem_;
dnnl::memory src_scales_mem_;
dnnl::memory dst_mem_; dnnl::memory dst_mem_;
dnnl::reorder reorder_p_; dnnl::reorder reorder_p_;
}; };
...@@ -107,7 +115,16 @@ void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) { ...@@ -107,7 +115,16 @@ void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) {
auto &reorder_dst_mem_p = handler_->AcquireDstMemory(y); auto &reorder_dst_mem_p = handler_->AcquireDstMemory(y);
auto reorder_p = handler_->AcquireReorder(); auto reorder_p = handler_->AcquireReorder();
auto &astream = OneDNNContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
reorder_p.execute(astream, reorder_src_mem_p, reorder_dst_mem_p);
std::unordered_map<int, dnnl::memory> reorder_args;
reorder_args.insert({DNNL_ARG_SRC, reorder_src_mem_p});
reorder_args.insert({DNNL_ARG_DST, reorder_dst_mem_p});
if (static_cast<float>(this->alpha_) != 1.f) {
reorder_args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, handler_->AcquireAlphaMemory()});
}
reorder_p.execute(astream, reorder_args);
astream.wait(); astream.wait();
} }
......
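For reference, the AXPY handler change above can be condensed into the following sketch of y = alpha * x + y built from a reorder with a sum post-op, where alpha travels as a runtime SRC scale instead of a v2 output scale. Names are illustrative, not taken from the Paddle code.

#include <unordered_map>
#include "dnnl.hpp"

// Hedged sketch: y = alpha * x + y via a reorder with a sum post-op (oneDNN v3 style).
void axpy_like(const dnnl::engine &eng, dnnl::stream &strm,
               dnnl::memory &x_mem, dnnl::memory &y_mem, float alpha) {
  dnnl::primitive_attr attr;
  dnnl::post_ops ops;
  ops.append_sum(1.0f);  // accumulate into the existing y values
  attr.set_post_ops(ops);

  std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, x_mem},
                                                {DNNL_ARG_DST, y_mem}};
  dnnl::memory alpha_mem;
  if (alpha != 1.0f) {
    attr.set_scales_mask(DNNL_ARG_SRC, 0);  // alpha becomes a runtime SRC scale
    dnnl::memory::desc alpha_md({1}, dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::x);
    alpha_mem = dnnl::memory(alpha_md, eng, &alpha);
    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, alpha_mem});
  }

  dnnl::reorder(x_mem, y_mem, attr).execute(strm, args);
  strm.wait();
}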
...@@ -36,10 +36,12 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx, ...@@ -36,10 +36,12 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx,
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
funcs::AppendActivation( funcs::AppendActivation(
dev_ctx, post_operations, 1.0f, fuse_activation, fuse_alpha, fuse_beta); dev_ctx, post_operations, fuse_activation, fuse_alpha, fuse_beta);
if (fused_output_scale != 1.0) { if (fused_output_scale != 1.0) {
// linear post-op's formula is `alpha * dst + beta`. Here we only want to
// scale the output, not shift it, so beta is set to 0.0f.
post_operations.append_eltwise( post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f); dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
} }
auto* non_const_x = &x; auto* non_const_x = &x;
...@@ -96,10 +98,19 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx, ...@@ -96,10 +98,19 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx,
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
const std::unordered_map<int, dnnl::memory> args = { std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC_0, *src_x_memory},
{DNNL_ARG_SRC_0, *src_x_memory}, {DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_SRC_1, *src_y_memory}, {DNNL_ARG_DST, *dst_memory}};
{DNNL_ARG_DST, *dst_memory}};
if (handler.Has_SRC_0_Scale()) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
handler.Get_SRC_0_Scale_Memory()});
}
if (handler.Has_SRC_1_Scale()) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
handler.Get_SRC_1_Scale_Memory()});
}
binary_prim->execute(astream, args); binary_prim->execute(astream, args);
astream.wait(); astream.wait();
...@@ -107,7 +118,7 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx, ...@@ -107,7 +118,7 @@ void FusedElementwiseKernel(const OneDNNContext& dev_ctx,
auto out_md = dst_memory->get_desc(); auto out_md = dst_memory->get_desc();
if (handler.use_broadcasting_hack) { if (handler.use_broadcasting_hack) {
auto dims = out_md.dims(); auto dims = out_md.get_dims();
dims.insert(dims.begin(), non_const_x->dims()[0]); dims.insert(dims.begin(), non_const_x->dims()[0]);
dims[1] /= dims[0]; dims[1] /= dims[0];
out_md = out_md.reshape(dims); out_md = out_md.reshape(dims);
......
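A small sketch of the post-op signature change used above: in v2 append_eltwise took a leading scale argument, while in v3 it takes only the algorithm plus alpha and beta, so a pure output scaling is expressed as eltwise_linear with beta = 0.

#include "dnnl.hpp"

// Hedged sketch of the post-op signature change relied on above.
// v2.x: post_ops.append_eltwise(/*scale=*/1.0f, algorithm::eltwise_linear, alpha, beta);
// v3.x: the leading scale argument is gone.
dnnl::post_ops make_output_scale_post_op(float fused_output_scale) {
  dnnl::post_ops ops;
  // eltwise_linear computes alpha * x + beta; beta = 0 keeps this a pure scaling.
  ops.append_eltwise(dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
  return ops;
}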
...@@ -139,17 +139,6 @@ class FusedMatmulOneDNNHandler ...@@ -139,17 +139,6 @@ class FusedMatmulOneDNNHandler
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
} }
float ComputeOutputScale(float matmul_alpha,
const float scale_x,
const float scale_y,
const float scale_in_eltwise UNUSED,
const float scale_out,
const bool force_fp32_output) {
float f_scale_out = force_fp32_output ? 1.0f : scale_out;
matmul_alpha *= f_scale_out / (scale_x * scale_y);
return matmul_alpha;
}
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext &dev_ctx, dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext &dev_ctx,
const DenseTensor *residual_data, const DenseTensor *residual_data,
const float matmul_alpha, const float matmul_alpha,
...@@ -165,14 +154,17 @@ class FusedMatmulOneDNNHandler ...@@ -165,14 +154,17 @@ class FusedMatmulOneDNNHandler
dnnl::primitive_attr matmul_attrs; dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
float computed_scale_out = ComputeOutputScale(matmul_alpha, if (scale_x != 1.0f) {
scale_x, matmul_attrs.set_scales_mask(DNNL_ARG_SRC, 0);
scale_y, }
scale_in_eltwise,
scale_out, // alpha can be folded to weight scale
force_fp32_output); if (scale_y != 1.0f || matmul_alpha != 1.0f) {
if (computed_scale_out != 1.0f) { matmul_attrs.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
matmul_attrs.set_output_scales(0, {computed_scale_out}); }
if (!force_fp32_output && scale_out != 1.0f) {
matmul_attrs.set_scales_mask(DNNL_ARG_DST, 0);
} }
if (residual_data) { if (residual_data) {
...@@ -183,17 +175,17 @@ class FusedMatmulOneDNNHandler ...@@ -183,17 +175,17 @@ class FusedMatmulOneDNNHandler
post_operations.append_binary(dnnl::algorithm::binary_add, post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md); residual_data_md);
if (scale_in_eltwise != 0.0f) { if (scale_in_eltwise != 0.0f) {
float sum_scale = scale_out / scale_in_eltwise; float sum_scale = 1.f / scale_in_eltwise;
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
} }
} }
funcs::AppendActivation( funcs::AppendActivation(
dev_ctx, post_operations, 1.0f, fuse_activation, fuse_alpha, fuse_beta); dev_ctx, post_operations, fuse_activation, fuse_alpha, fuse_beta);
if (fused_output_scale != 1.0f) { if (fused_output_scale != 1.0f) {
post_operations.append_eltwise( post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f); dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
} }
matmul_attrs.set_post_ops(post_operations); matmul_attrs.set_post_ops(post_operations);
...@@ -281,6 +273,37 @@ void ExecuteFusedMatmul(const OneDNNContext &dev_ctx, ...@@ -281,6 +273,37 @@ void ExecuteFusedMatmul(const OneDNNContext &dev_ctx,
*residual_data_memory_p}); *residual_data_memory_p});
} }
if (scale_x != 1.0f) {
dnnl::memory::desc src_scales_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto src_scales_mem =
std::make_shared<dnnl::memory>(src_scales_md, dev_ctx.GetEngine());
*reinterpret_cast<float *>(src_scales_mem->get_data_handle()) =
1.f / scale_x;
matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *src_scales_mem});
}
if (scale_y != 1.0f || matmul_alpha != 1.0f) {
dnnl::memory::desc wei_scales_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto wei_scales_mem =
std::make_shared<dnnl::memory>(wei_scales_md, dev_ctx.GetEngine());
*reinterpret_cast<float *>(wei_scales_mem->get_data_handle()) =
matmul_alpha / scale_y;
matmul_args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *wei_scales_mem});
}
if (!force_fp32_output && scale_out != 1.0f) {
dnnl::memory::desc dst_scales_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_scales_mem =
std::make_shared<dnnl::memory>(dst_scales_md, dev_ctx.GetEngine());
*reinterpret_cast<float *>(dst_scales_mem->get_data_handle()) =
1.f / scale_out;
matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *dst_scales_mem});
}
auto &astream = OneDNNContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
......
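The per-argument scales filled in above reproduce the single folded output scale that the deleted ComputeOutputScale returned. Assuming oneDNN v3's quantization convention (stated here for clarity, not quoted from the diff) that the primitive effectively computes dst = op(sigma_src * src, sigma_wei * wei) / sigma_dst, the values written into the three scale memories combine to the same factor:

\sigma_{\mathrm{src}} = \frac{1}{\mathrm{scale\_x}}, \qquad
\sigma_{\mathrm{wei}} = \frac{\mathrm{matmul\_alpha}}{\mathrm{scale\_y}}, \qquad
\sigma_{\mathrm{dst}} = \frac{1}{\mathrm{scale\_out}}
\;\Longrightarrow\;
\frac{\sigma_{\mathrm{src}}\,\sigma_{\mathrm{wei}}}{\sigma_{\mathrm{dst}}}
  = \frac{\mathrm{matmul\_alpha}\cdot\mathrm{scale\_out}}{\mathrm{scale\_x}\cdot\mathrm{scale\_y}}

which matches the v2 value matmul_alpha * scale_out / (scale_x * scale_y) when force_fp32_output is false.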
...@@ -27,7 +27,7 @@ void SetInMemDescWithSqueeze2FuseSupport( ...@@ -27,7 +27,7 @@ void SetInMemDescWithSqueeze2FuseSupport(
const dnnl::memory::desc& in_md) { const dnnl::memory::desc& in_md) {
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(), const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end()); fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims(); const std::vector<int64_t>& x_vec_dims = in_md.get_dims();
std::vector<int64_t> squeezed_op_tz( std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0); x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
...@@ -113,12 +113,12 @@ void FusedTransposeKernel(const Context& dev_ctx, ...@@ -113,12 +113,12 @@ void FusedTransposeKernel(const Context& dev_ctx,
const int32_t mask = 0; const int32_t mask = 0;
if (scale != 1.0f) { if (scale != 1.0f) {
attrs.set_output_scales(mask, {scale}); attrs.set_scales_mask(DNNL_ARG_SRC, mask);
} }
if (shift != 0.0f) { if (shift != 0.0f) {
auto dst = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST; auto arg = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST;
attrs.set_zero_points(dst, mask, {static_cast<int32_t>(shift)}); attrs.set_zero_points_mask(arg, mask);
} }
DataType out_dtype; DataType out_dtype;
...@@ -149,8 +149,31 @@ void FusedTransposeKernel(const Context& dev_ctx, ...@@ -149,8 +149,31 @@ void FusedTransposeKernel(const Context& dev_ctx,
auto reorder_p = reorder_handler.AcquireReorder( auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs); reorder_dst_memory_p, reorder_src_memory_p, attrs);
std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *reorder_dst_memory_p},
};
if (scale != 1.0f) {
auto scales_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto scales = dnnl::memory(
scales_md, dev_ctx.GetEngine(), const_cast<float*>(&scale));
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales});
}
if (shift != 0.0f) {
auto zps_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
auto zps = dnnl::memory(zps_md, dev_ctx.GetEngine());
*reinterpret_cast<int32_t*>(zps.get_data_handle()) =
static_cast<int32_t>(shift);
auto arg = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST;
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | arg, zps});
}
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, args);
astream.wait(); astream.wait();
auto out_md = reorder_dst_memory_p->get_desc().permute_axes( auto out_md = reorder_dst_memory_p->get_desc().permute_axes(
...@@ -164,7 +187,7 @@ void FusedTransposeKernel(const Context& dev_ctx, ...@@ -164,7 +187,7 @@ void FusedTransposeKernel(const Context& dev_ctx,
fused_reshape2_shape, out, out_md); fused_reshape2_shape, out, out_md);
} else if (!fused_squeeze2_axes.empty()) { } else if (!fused_squeeze2_axes.empty()) {
out->set_mem_desc(out_md); out->set_mem_desc(out_md);
out->Resize(make_ddim(out_md.dims())); out->Resize(make_ddim(out_md.get_dims()));
} else { } else {
out->set_mem_desc(out_md); out->set_mem_desc(out_md);
} }
......
...@@ -247,7 +247,11 @@ void HardSwishGradKernel(const Context& dev_ctx, ...@@ -247,7 +247,11 @@ void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& dout, const DenseTensor& dout,
DenseTensor* dx) { DenseTensor* dx) {
HardSwishOneDNNGradFunctor<T> functor; HardSwishOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, 0, 0, dx); // the formula of oneDNN hardswish primitive is:
  // d = s * max(0, min(1, alpha*s + beta)). Here we set alpha = 1/6 and beta = 1/2 to make
  // the formula match the hardswish definition in Paddle:
// https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/functional/hardswish_en.html
functor(dev_ctx, x, dout, 1.0 / 6.0, 1.0 / 2.0, dx);
} }
template <typename T, typename Context> template <typename T, typename Context>
......
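A one-line check of the equivalence claimed in the hardswish comment above, using Paddle's definition hardswish(x) = x * min(max(x+3, 0), 6) / 6:

x \cdot \frac{\min(\max(x+3,\,0),\,6)}{6}
  \;=\; x \cdot \max\!\Bigl(0,\ \min\!\bigl(1,\ \tfrac{1}{6}x + \tfrac{1}{2}\bigr)\Bigr)

so passing alpha = 1/6 and beta = 1/2 to the oneDNN hardswish formula d = s * max(0, min(1, alpha*s + beta)) reproduces it exactly.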
...@@ -160,7 +160,7 @@ void HardSwishKernel(const Context& dev_ctx, ...@@ -160,7 +160,7 @@ void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out) { DenseTensor* out) {
HardSwishOneDNNFunctor<T> functor; HardSwishOneDNNFunctor<T> functor;
functor(dev_ctx, x, 6, 0, out); functor(dev_ctx, x, 1.0 / 6.0, 1.0 / 2.0, out);
} }
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -49,16 +49,6 @@ class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> { ...@@ -49,16 +49,6 @@ class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md); this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
} }
// (jczaja) sum oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
const dnnl::memory::desc& dst_md,
const std::vector<float>& scales,
const std::vector<dnnl::memory::desc>& srcs_md) {
this->fwd_pd_.reset(
new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input, std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input,
int i) { int i) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
......
...@@ -43,19 +43,17 @@ void BatchNormGradRawKernel(const Context& dev_ctx, ...@@ -43,19 +43,17 @@ void BatchNormGradRawKernel(const Context& dev_ctx,
funcs::BatchNormOneDNNHandler<T> handler( funcs::BatchNormOneDNNHandler<T> handler(
dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad); dev_ctx.GetEngine(), dev_ctx.GetPlace(), epsilon, &x, &scale, &y_grad);
const unsigned int C = vectorize(scale.dims())[0]; T* diff_scale_data = dev_ctx.template Alloc<T>(scale_grad);
const size_t scaleshift_size = 2 * C; T* diff_shift_data = dev_ctx.template Alloc<T>(bias_grad);
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
auto src_memory = handler.AcquireSrcMemory(&x); auto src_memory = handler.AcquireSrcMemory(&x);
auto mean_memory = handler.AcquireMeanMemory(&saved_mean); auto mean_memory = handler.AcquireMeanMemory(&saved_mean);
auto variance_memory = handler.AcquireVarianceMemory(&saved_variance); auto variance_memory = handler.AcquireVarianceMemory(&saved_variance);
auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad); auto diff_dst_memory = handler.AcquireDiffDstMemory(&y_grad);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias); auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad); auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto diff_scaleshift_memory = auto diff_scaleshift_mems =
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); handler.AcquireDiffScaleShiftMemory(diff_scale_data, diff_shift_data);
auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive(); auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();
...@@ -66,20 +64,12 @@ void BatchNormGradRawKernel(const Context& dev_ctx, ...@@ -66,20 +64,12 @@ void BatchNormGradRawKernel(const Context& dev_ctx,
{DNNL_ARG_MEAN, *mean_memory}, {DNNL_ARG_MEAN, *mean_memory},
{DNNL_ARG_VARIANCE, *variance_memory}, {DNNL_ARG_VARIANCE, *variance_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory}, {DNNL_ARG_DIFF_DST, *diff_dst_memory},
{DNNL_ARG_SCALE_SHIFT, *scaleshift_memory}, {DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))},
{DNNL_ARG_DIFF_SRC, *diff_src_memory}, {DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_SCALE_SHIFT, *diff_scaleshift_memory}}); {DNNL_ARG_DIFF_SCALE, *(std::get<0>(diff_scaleshift_mems))},
{DNNL_ARG_DIFF_SHIFT, *(std::get<1>(diff_scaleshift_mems))}});
astream.wait(); astream.wait();
T* diff_scale_data = dev_ctx.template Alloc<T>(scale_grad);
T* diff_shift_data = dev_ctx.template Alloc<T>(bias_grad);
// copy back diff scale/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, C), diff_scale_data);
std::copy(std::next(it, C), std::end(diff_scaleshift_data), diff_shift_data);
// set memory descriptor of out tensor // set memory descriptor of out tensor
x_grad->set_mem_desc(diff_src_memory->get_desc()); x_grad->set_mem_desc(diff_src_memory->get_desc());
} }
......
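Sketch of the argument-map change the batch-norm kernels above rely on: v2's packed 2xC DNNL_ARG_SCALE_SHIFT buffer becomes separate C-element SCALE and SHIFT memories, which is why diff scale/shift can now be written straight into the output tensors instead of being copied out of a temporary. Primitive and memory names below are illustrative.

#include <unordered_map>
#include "dnnl.hpp"

// Hedged sketch: forward batch-norm execution with separate scale/shift (oneDNN v3 style).
// Assumes the primitive descriptor was built with use_scale | use_shift flags.
void run_bnorm_fwd(dnnl::batch_normalization_forward &bnorm,
                   dnnl::stream &strm,
                   dnnl::memory &src, dnnl::memory &dst,
                   dnnl::memory &mean, dnnl::memory &variance,
                   dnnl::memory &scale, dnnl::memory &shift) {
  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_SRC, src},
      {DNNL_ARG_MEAN, mean},
      {DNNL_ARG_VARIANCE, variance},
      {DNNL_ARG_SCALE, scale},   // v2: one interleaved DNNL_ARG_SCALE_SHIFT memory
      {DNNL_ARG_SHIFT, shift},
      {DNNL_ARG_DST, dst}};
  bnorm.execute(strm, args);
  strm.wait();
}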
...@@ -58,7 +58,7 @@ void BatchNormKernel(const Context &dev_ctx, ...@@ -58,7 +58,7 @@ void BatchNormKernel(const Context &dev_ctx,
test_mode); test_mode);
auto src_memory = handler.AcquireSrcMemory(&x); auto src_memory = handler.AcquireSrcMemory(&x);
auto scaleshift_memory = handler.AcquireScaleShiftMemory(&scale, &bias); auto scaleshift_mems = handler.AcquireScaleShiftMemory(&scale, &bias);
auto dst_memory = handler.AcquireDstMemory(y); auto dst_memory = handler.AcquireDstMemory(y);
auto batch_norm_p = handler.AcquireForwardPrimitive(); auto batch_norm_p = handler.AcquireForwardPrimitive();
...@@ -79,7 +79,8 @@ void BatchNormKernel(const Context &dev_ctx, ...@@ -79,7 +79,8 @@ void BatchNormKernel(const Context &dev_ctx,
auto &astream = OneDNNContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
batch_norm_p->execute(astream, batch_norm_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory}, {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_SCALE_SHIFT, *scaleshift_memory}, {DNNL_ARG_SCALE, *(std::get<0>(scaleshift_mems))},
{DNNL_ARG_SHIFT, *(std::get<1>(scaleshift_mems))},
{DNNL_ARG_MEAN, *mean_memory}, {DNNL_ARG_MEAN, *mean_memory},
{DNNL_ARG_VARIANCE, *variance_memory}, {DNNL_ARG_VARIANCE, *variance_memory},
{DNNL_ARG_DST, *dst_memory}}); {DNNL_ARG_DST, *dst_memory}});
......
...@@ -63,16 +63,6 @@ class ConcatOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::concat> { ...@@ -63,16 +63,6 @@ class ConcatOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::concat> {
this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md); this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
} }
// (jczaja) concat oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
const memory::desc& dst_md,
const int concat_axis,
const std::vector<memory::desc>& srcs_md) {
this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
dst_md, concat_axis, srcs_md, this->engine_));
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor& input, std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor& input,
int i) { int i) {
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
......
...@@ -241,30 +241,23 @@ void ComputeINT8(const OneDNNContext& dev_ctx, ...@@ -241,30 +241,23 @@ void ComputeINT8(const OneDNNContext& dev_ctx,
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
if (bias) { if (bias) {
std::vector<float> bias_scales; auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, true);
auto p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
if (dev_ctx.HasDnnAttr("Bias_scales")) {
bias_scales = PADDLE_GET_CONST(std::vector<float>,
dev_ctx.GetDnnAttr("Bias_scales"));
p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
} else {
p_scales_tuple = handler.get_int8_bias_scales(
filter, groups, scale_weights_data);
}
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias,
true,
std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple));
args.insert({DNNL_ARG_BIAS, *bias_memory_p}); args.insert({DNNL_ARG_BIAS, *bias_memory_p});
} }
auto src_scales_memory = handler.AcquireScalesMemory(DNNL_ARG_SRC);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *src_scales_memory});
auto wei_scales_memory = handler.AcquireScalesMemory(DNNL_ARG_WEIGHTS);
args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *wei_scales_memory});
if (!force_fp32_output) {
auto dst_scales_memory = handler.AcquireScalesMemory(DNNL_ARG_DST);
args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *dst_scales_memory});
}
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
conv_p->execute(astream, args); conv_p->execute(astream, args);
astream.wait(); astream.wait();
......
...@@ -127,7 +127,7 @@ void ConvGradKernel(const Context& dev_ctx, ...@@ -127,7 +127,7 @@ void ConvGradKernel(const Context& dev_ctx,
// goidhw) for 2d conv with groups (five dimensional data reorder // goidhw) for 2d conv with groups (five dimensional data reorder
// to goihw) auto weights_tz = phi::vectorize(filter->dims()); // to goihw) auto weights_tz = phi::vectorize(filter->dims());
auto weights_tz = diff_weights_memory_p->get_desc().dims(); auto weights_tz = diff_weights_memory_p->get_desc().get_dims();
dnnl::memory::format_tag out_format = dnnl::memory::format_tag out_format =
weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw
: dnnl::memory::format_tag::goihw; : dnnl::memory::format_tag::goihw;
......
...@@ -183,8 +183,7 @@ class ConvOneDNNHandlerT ...@@ -183,8 +183,7 @@ class ConvOneDNNHandlerT
const auto dst_md = funcs::OneDNNMemDesc( const auto dst_md = funcs::OneDNNMemDesc(
dst_tz, funcs::OneDNNGetDataType<T_out>(), chosen_memory_format); dst_tz, funcs::OneDNNGetDataType<T_out>(), chosen_memory_format);
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference const auto fwd_prop_kind = dnnl::prop_kind::forward_inference;
: dnnl::prop_kind::forward_training;
const dnnl::primitive_attr conv_attr = CreateConvAttrs(filter, const dnnl::primitive_attr conv_attr = CreateConvAttrs(filter,
groups, groups,
force_fp32_output, force_fp32_output,
...@@ -193,15 +192,10 @@ class ConvOneDNNHandlerT ...@@ -193,15 +192,10 @@ class ConvOneDNNHandlerT
if (bias) { if (bias) {
auto bias_tz = phi::vectorize(bias->dims()); auto bias_tz = phi::vectorize(bias->dims());
dnnl::memory::desc bias_md; dnnl::memory::desc bias_md =
if (funcs::is_int8<T>()) { funcs::OneDNNMemDesc(bias_tz,
bias_md = funcs::OneDNNMemDesc(bias_tz, dnnl::memory::data_type::f32,
dnnl::memory::data_type::s32, funcs::OneDNNMemoryFormat::x);
funcs::OneDNNMemoryFormat::x);
} else {
bias_md = funcs::OneDNNMemDesc(
bias_tz, data_type, funcs::OneDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
conv_attr, conv_attr,
...@@ -340,20 +334,14 @@ class ConvOneDNNHandlerT ...@@ -340,20 +334,14 @@ class ConvOneDNNHandlerT
dnnl::primitive_attr conv_attr; dnnl::primitive_attr conv_attr;
if (bias) { if (bias) {
auto bias_tz = phi::vectorize(bias->dims()); auto bias_tz = phi::vectorize(bias->dims());
dnnl::memory::desc bias_md; dnnl::memory::desc bias_md =
if (funcs::is_int8<T>()) { funcs::OneDNNMemDesc(bias_tz,
bias_md = funcs::OneDNNMemDesc(bias_tz, dnnl::memory::data_type::f32,
dnnl::memory::data_type::s32, funcs::OneDNNMemoryFormat::x);
funcs::OneDNNMemoryFormat::x);
} else {
bias_md = funcs::OneDNNMemDesc(bias_tz,
dnnl::memory::data_type::f32,
funcs::OneDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
conv_attr, conv_attr,
dnnl::prop_kind::forward_training, dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_direct, dnnl::algorithm::convolution_direct,
src_md, src_md,
weights_md, weights_md,
...@@ -366,7 +354,7 @@ class ConvOneDNNHandlerT ...@@ -366,7 +354,7 @@ class ConvOneDNNHandlerT
} else { } else {
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
conv_attr, conv_attr,
dnnl::prop_kind::forward_training, dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_direct, dnnl::algorithm::convolution_direct,
src_md, src_md,
weights_md, weights_md,
...@@ -399,110 +387,6 @@ class ConvOneDNNHandlerT ...@@ -399,110 +387,6 @@ class ConvOneDNNHandlerT
} }
} }
std::shared_ptr<std::tuple<float, std::vector<float>>> get_int8_bias_scales(
const DenseTensor* filter,
int groups,
const std::vector<float>& scale_weights_data) {
// Get scales int8 bias key
const std::string key_bs = this->key_ + "@bs";
// Scales for int8 bias are to be cached to avoid
// computing them each iteration
groups = std::max(groups, 1);
auto bias_scale_tuple =
std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
this->dev_ctx_.GetBlob(key_bs));
if (bias_scale_tuple) return bias_scale_tuple;
const auto& weights_tz = phi::vectorize(filter->dims());
const auto& scale_in_data =
this->dev_ctx_.HasDnnAttr("Scale_in")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_in"))
: 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
int mask_reorder = is_multi_channel ? 1 << 0 : 1;
int count = 1;
if (is_multi_channel) {
count *= weights_tz[0];
if (groups > 1) {
count *= weights_tz[1];
}
}
bias_scale_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(std::make_tuple(
static_cast<float>(mask_reorder), std::vector<float>(count)));
for (int i = 0; i < count; i++) {
std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i];
}
this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple);
return bias_scale_tuple;
}
std::tuple<float, std::vector<float>, float> get_int8_scales(
const DenseTensor* filter,
int groups,
bool force_fp32_output,
bool fuse_residual_conn,
const std::string& fuse_activation) const {
const auto& weights_tz = phi::vectorize(filter->dims());
groups = std::max(groups, 1);
const auto& scale_weights_data =
this->dev_ctx_.HasDnnAttr("Scale_weights")
? PADDLE_GET_CONST(std::vector<float>,
this->dev_ctx_.GetDnnAttr("Scale_weights"))
: std::vector<float>{1.0f};
const auto& scale_in_data =
this->dev_ctx_.HasDnnAttr("Scale_in")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_in"))
: 1.0f;
const auto& scale_in_eltwise_data =
this->dev_ctx_.HasDnnAttr("Scale_in_eltwise")
? PADDLE_GET_CONST(float,
this->dev_ctx_.GetDnnAttr("Scale_in_eltwise"))
: 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
bool has_activation = !fuse_activation.empty();
const auto& scale_out =
this->dev_ctx_.HasDnnAttr("Scale_out")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_out"))
: 1.0f;
float activation_scale =
(!force_fp32_output && has_activation) ? scale_out : 1.0f;
float scale_out_data =
(force_fp32_output || has_activation) ? 1.0f : scale_out;
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
int count =
is_multi_channel
? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
: 1;
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 50)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
// weights data will contain 0 in some models, then weights
// scale couldn't be calculated
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
}
dnnl::primitive_attr CreateConvAttrs(const DenseTensor* filter, dnnl::primitive_attr CreateConvAttrs(const DenseTensor* filter,
int groups, int groups,
bool force_fp32_output, bool force_fp32_output,
...@@ -512,36 +396,30 @@ class ConvOneDNNHandlerT ...@@ -512,36 +396,30 @@ class ConvOneDNNHandlerT
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
float sum_scale = 1.0f; float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale; std::vector<float> output_shift_scale;
if (funcs::is_int8<T>()) { if (funcs::is_int8<T>()) {
if (this->dev_ctx_.HasDnnAttr("Sum_scale")) { conv_attr.set_scales_mask(DNNL_ARG_SRC, 0);
sum_scale =
PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Sum_scale")); auto wei_scales = ConvertToDNNLScales("Scale_weights");
activation_scale = // By oneDNN API definition:
this->dev_ctx_.HasDnnAttr("Activation_scale") // - For per-tensor quantization: the mask should be 0
? PADDLE_GET_CONST( // - For per-dimension quantization: the mask should be 1 <<
float, this->dev_ctx_.GetDnnAttr("Activation_scale")) // dimension_index Here, wei_scales.size() != 1 means per-channel
: activation_scale; // quantization, the channel index in oneDNN is always 0, so we use mask =
output_shift_scale = // 1 << 0. If the conv is group, the weights shape will be [g, oc/g, ic,
this->dev_ctx_.HasDnnAttr("Output_shift_scale") // h, w], we need to do scaling along both group dim and oc dim, so the
? PADDLE_GET_CONST( // mask = (1 << 0) + (1 << 1).
std::vector<float>, int mask = wei_scales.size() == 1
this->dev_ctx_.GetDnnAttr("Output_shift_scale")) ? 0
: output_shift_scale; : (groups > 1 ? ((1 << 0) + (1 << 1)) : 1 << 0);
} else { conv_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(filter, if (!force_fp32_output) {
groups, conv_attr.set_scales_mask(DNNL_ARG_DST, 0);
force_fp32_output,
fuse_residual_conn,
fuse_activation);
} }
if (output_shift_scale.size() > 0) { auto psum_scales = ConvertToDNNLScales("Scale_in_eltwise");
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; sum_scale = psum_scales[0];
conv_attr.set_output_scales(mask, output_shift_scale);
}
} }
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
...@@ -553,7 +431,7 @@ class ConvOneDNNHandlerT ...@@ -553,7 +431,7 @@ class ConvOneDNNHandlerT
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
} }
funcs::AppendActivation(this->dev_ctx_, post_operations, activation_scale); funcs::AppendActivation(this->dev_ctx_, post_operations);
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
...@@ -750,6 +628,69 @@ class ConvOneDNNHandlerT ...@@ -750,6 +628,69 @@ class ConvOneDNNHandlerT
this->AcquireReorder(residual_memory_p, dst_memory_p); this->AcquireReorder(residual_memory_p, dst_memory_p);
return dst_memory_p; return dst_memory_p;
} }
  // Currently, 4 kinds of oneDNN scales are supported: src scales, weight
  // scales, post-sum scales and dst scales. This function converts
  // Paddle scales to oneDNN scales.
std::vector<float> ConvertToDNNLScales(const std::string& attr_name) {
std::vector<float> paddle_scales;
    // weight scales are a vector but the other scales are scalars
if (attr_name == "Scale_weights") {
paddle_scales =
this->dev_ctx_.HasDnnAttr(attr_name)
? PADDLE_GET_CONST(std::vector<float>,
this->dev_ctx_.GetDnnAttr(attr_name))
: std::vector<float>{1.0f};
} else {
float scale =
this->dev_ctx_.HasDnnAttr(attr_name)
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr(attr_name))
: 1.0f;
paddle_scales = std::vector<float>{scale};
}
size_t count = paddle_scales.size();
std::vector<float> dnnl_scales(count);
#pragma omp parallel for if (count > 50)
for (size_t i = 0; i < count; i++) {
dnnl_scales[i] = 1.f / paddle_scales[i];
}
return dnnl_scales;
}
std::shared_ptr<dnnl::memory> AcquireScalesMemory(int dnnl_arg) {
// <dnnl_arg, {cache_key_suffix, attr_name}>
std::unordered_map<int, std::pair<std::string, std::string>> map = {
{DNNL_ARG_SRC, {"@src_scales", "Scale_in"}},
{DNNL_ARG_WEIGHTS, {"@wei_scales", "Scale_weights"}},
{DNNL_ARG_DST, {"@dst_scales", "Scale_out"}},
};
std::string cache_key_suffix, attr_name;
std::tie(cache_key_suffix, attr_name) = map.at(dnnl_arg);
// first look up the cache
auto dnnl_scales_mem = this->AcquireMemory(cache_key_suffix);
if (!dnnl_scales_mem) {
// cache miss, so construct scales memory from the paddle scales
// attributes
auto dnnl_scales = ConvertToDNNLScales(attr_name);
dnnl::memory::desc dnnl_scales_md(
{static_cast<int64_t>(dnnl_scales.size())},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::x);
dnnl_scales_mem =
std::make_shared<dnnl::memory>(dnnl_scales_md, this->engine_);
memcpy(dnnl_scales_mem->get_data_handle(),
dnnl_scales.data(),
dnnl_scales.size() * sizeof(float));
// cache the constructed memory
this->CacheMemory(cache_key_suffix, dnnl_scales_mem);
}
return dnnl_scales_mem;
}
}; };
} // namespace onednn } // namespace onednn
......
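A condensed sketch of the two conventions the new conv-handler comments describe: Paddle scales multiply float values into the int domain, so the oneDNN scale is their reciprocal, and the weights scale mask is 0 for per-tensor scales, 1 << 0 for per-output-channel scales, and (1 << 0) + (1 << 1) when a groups dimension is prepended to the weights. The helper names are illustrative only.

#include <cstddef>
#include <vector>

// Hedged sketch of the scale/mask conventions used by the conv handler above.
std::vector<float> to_onednn_scales(const std::vector<float> &paddle_scales) {
  std::vector<float> out(paddle_scales.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    out[i] = 1.f / paddle_scales[i];  // oneDNN expects the reciprocal of the Paddle scale
  }
  return out;
}

int weights_scale_mask(std::size_t num_scales, int groups) {
  if (num_scales == 1) return 0;           // per-tensor quantization
  return groups > 1 ? (1 << 0) + (1 << 1)  // [g, oc/g, ic, kh, kw]: scale over g and oc/g
                    : (1 << 0);            // [oc, ic, kh, kw]: scale over oc
}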
...@@ -253,7 +253,7 @@ class ConvTransposeOneDNNHandlerT ...@@ -253,7 +253,7 @@ class ConvTransposeOneDNNHandlerT
dnnl::reorder::primitive_desc reorder_pdesc; dnnl::reorder::primitive_desc reorder_pdesc;
if (funcs::is_int8<T>()) { if (funcs::is_int8<T>()) {
dnnl::primitive_attr attr; dnnl::primitive_attr attr;
attr.set_output_scales(mask, scale_data); attr.set_scales_mask(DNNL_ARG_DST, mask);
reorder_pdesc = dnnl::reorder::primitive_desc( reorder_pdesc = dnnl::reorder::primitive_desc(
*user_memory_p, *target_memory_p, attr); *user_memory_p, *target_memory_p, attr);
} else { } else {
...@@ -264,9 +264,22 @@ class ConvTransposeOneDNNHandlerT ...@@ -264,9 +264,22 @@ class ConvTransposeOneDNNHandlerT
dev_ctx.SetBlob(key_reorder_p, reorder_p); dev_ctx.SetBlob(key_reorder_p, reorder_p);
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(
astream, std::unordered_map<int, dnnl::memory> reorder_args;
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}}); reorder_args.insert({DNNL_ARG_SRC, *user_memory_p});
reorder_args.insert({DNNL_ARG_DST, *target_memory_p});
if (funcs::is_int8<T>()) {
auto scale_md =
dnnl::memory::desc({static_cast<int64_t>(scale_data.size())},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::x);
auto scale_data_mem = dnnl::memory(scale_md, this->engine_);
scale_data_mem.set_data_handle(
phi::funcs::to_void_cast(scale_data.data()));
reorder_args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scale_data_mem});
}
reorder_p->execute(astream, reorder_args);
astream.wait(); astream.wait();
} else { } else {
target_memory_p = user_memory_p; target_memory_p = user_memory_p;
......
...@@ -28,17 +28,21 @@ void DeQuantKernel(const Context& dev_ctx, ...@@ -28,17 +28,21 @@ void DeQuantKernel(const Context& dev_ctx,
const float quantization_scale, const float quantization_scale,
const float quantization_shift, const float quantization_shift,
DenseTensor* out) { DenseTensor* out) {
const bool with_shift = quantization_shift != 0.0f;
PADDLE_ENFORCE(quantization_scale != 0.0f, PADDLE_ENFORCE(quantization_scale != 0.0f,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Dequantization scale must be different than 0.0f")); "Dequantization scale must be different than 0.0f"));
PADDLE_ENFORCE(quantization_shift <= 255 && quantization_shift >= 0, const auto q_shift = static_cast<int32_t>(quantization_shift);
phi::errors::InvalidArgument( PADDLE_ENFORCE_GE(q_shift,
"Dequantization shift must be lower or equal to ", 0,
"255 and greater or equal to 0, but got %f", phi::errors::InvalidArgument(
quantization_shift)); "Dequantization shift must be greater or equal to 0"));
PADDLE_ENFORCE_LE(q_shift,
255,
phi::errors::InvalidArgument(
"Dequantization shift must be lower or equal to 255"));
const bool with_shift = q_shift != 0;
auto x_tz = phi::vectorize<int64_t>(x.dims()); auto x_tz = phi::vectorize<int64_t>(x.dims());
auto x_type = phi::funcs::ToOneDNNDataType(x.dtype()); auto x_type = phi::funcs::ToOneDNNDataType(x.dtype());
...@@ -47,12 +51,10 @@ void DeQuantKernel(const Context& dev_ctx, ...@@ -47,12 +51,10 @@ void DeQuantKernel(const Context& dev_ctx,
dnnl::primitive_attr attrs; dnnl::primitive_attr attrs;
static constexpr int32_t mask = 0; // same shift and scale for whole tensor static constexpr int32_t mask = 0; // same shift and scale for whole tensor
const float reorder_scale = 1. / quantization_scale; attrs.set_scales_mask(DNNL_ARG_DST, mask);
attrs.set_output_scales(mask, {reorder_scale});
if (with_shift) { if (with_shift) {
attrs.set_zero_points( attrs.set_zero_points_mask(DNNL_ARG_SRC, mask);
DNNL_ARG_SRC, mask, {static_cast<int32_t>(quantization_shift)});
} }
phi::funcs::ReorderOneDNNHandler reorder_handler( phi::funcs::ReorderOneDNNHandler reorder_handler(
...@@ -67,7 +69,29 @@ void DeQuantKernel(const Context& dev_ctx, ...@@ -67,7 +69,29 @@ void DeQuantKernel(const Context& dev_ctx,
reorder_dst_memory_p, reorder_src_memory_p, attrs); reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = phi::OneDNNContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
auto scales_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto scales_mem =
dnnl::memory(scales_md,
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<float>(&quantization_scale));
auto zero_points_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::s32, dnnl::memory::format_tag::x);
auto zero_points_mem =
dnnl::memory(zero_points_md,
dev_ctx.GetEngine(),
phi::funcs::to_void_cast<int32_t>(&q_shift));
std::unordered_map<int, dnnl::memory> reorder_args;
reorder_args.insert({DNNL_ARG_SRC, *reorder_src_memory_p});
reorder_args.insert({DNNL_ARG_DST, *reorder_dst_memory_p});
reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_mem});
if (with_shift) {
reorder_args.insert(
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zero_points_mem});
}
reorder_p->execute(astream, reorder_args);
astream.wait(); astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc()); out->set_mem_desc(reorder_dst_memory_p->get_desc());
......
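The dequantize kernel above uses the same execution-time pattern for zero points: the attribute only records a mask and the s32 shift is bound to DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC when the reorder runs. A minimal sketch with illustrative names:

#include "dnnl.hpp"

// Hedged sketch: reorder with a runtime source zero point (oneDNN v3 style).
void reorder_with_src_zero_point(const dnnl::engine &eng, dnnl::stream &strm,
                                 dnnl::memory &src_mem, dnnl::memory &dst_mem,
                                 int32_t zero_point) {
  dnnl::primitive_attr attr;
  attr.set_zero_points_mask(DNNL_ARG_SRC, 0);  // v2: attr.set_zero_points(DNNL_ARG_SRC, 0, {zp});

  dnnl::memory::desc zp_md({1}, dnnl::memory::data_type::s32,
                           dnnl::memory::format_tag::x);
  dnnl::memory zp_mem(zp_md, eng, &zero_point);

  dnnl::reorder(src_mem, dst_mem, attr)
      .execute(strm, {{DNNL_ARG_SRC, src_mem},
                      {DNNL_ARG_DST, dst_mem},
                      {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_mem}});
  strm.wait();
}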
...@@ -49,14 +49,18 @@ inline void AddSubNonBroadcast(ReorderOneDNNHandler* reorder_handler, ...@@ -49,14 +49,18 @@ inline void AddSubNonBroadcast(ReorderOneDNNHandler* reorder_handler,
phi::DenseTensor* grad_tensor, phi::DenseTensor* grad_tensor,
const std::shared_ptr<dnnl::memory>& src_memory, const std::shared_ptr<dnnl::memory>& src_memory,
const std::shared_ptr<dnnl::memory>& dst_memory, const std::shared_ptr<dnnl::memory>& dst_memory,
const std::vector<float>& scales) { const dnnl::memory& scales_memory) {
dnnl::primitive_attr reorder_attr; dnnl::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, scales); reorder_attr.set_scales_mask(DNNL_ARG_DST, 0);
auto reorder_p = auto reorder_p =
reorder_handler->AcquireReorder(dst_memory, src_memory, reorder_attr); reorder_handler->AcquireReorder(dst_memory, src_memory, reorder_attr);
reorder_p->execute( std::unordered_map<int, dnnl::memory> args = {
OneDNNContext::tls().get_stream(), *src_memory, *dst_memory); {DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_memory}};
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, args);
} }
template <typename T> template <typename T>
...@@ -73,7 +77,7 @@ inline void BroadcastReduction(const Place& place, ...@@ -73,7 +77,7 @@ inline void BroadcastReduction(const Place& place,
// Broadcasting // Broadcasting
if (is_sub) { if (is_sub) {
dnnl::post_ops po; dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, scales[0], 0); po.append_eltwise(dnnl::algorithm::eltwise_linear, scales[0], 0);
broadcast_reduction_attr.set_post_ops(po); broadcast_reduction_attr.set_post_ops(po);
} }
...@@ -126,9 +130,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -126,9 +130,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
swap_x_y = true; swap_x_y = true;
} }
std::vector<float> scales{1.0}; float scale{1.0};
if (swap_x_y) { if (swap_x_y) {
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1; scale = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
} }
auto tz = phi::vectorize<int64_t>(dout.dims()); auto tz = phi::vectorize<int64_t>(dout.dims());
...@@ -143,6 +147,11 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -143,6 +147,11 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory; std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory;
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
auto scales_md = dnnl::memory::desc(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto scales_mem = dnnl::memory(scales_md, onednn_engine);
auto scale_memory_buf = static_cast<float*>(scales_mem.get_data_handle());
*scale_memory_buf = scale;
if (dx) { if (dx) {
// elementwise_add & elementwise_sub // elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add || if (BINARY_OP == dnnl::algorithm::binary_add ||
...@@ -151,7 +160,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -151,7 +160,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
dst_memory = reorder_handler.AcquireDstMemory( dst_memory = reorder_handler.AcquireDstMemory(
dx, dout.mem_desc(), dev_ctx.GetPlace()); dx, dout.mem_desc(), dev_ctx.GetPlace());
AddSubNonBroadcast( AddSubNonBroadcast(
&reorder_handler, dx, reorder_src_memory, dst_memory, scales); &reorder_handler, dx, reorder_src_memory, dst_memory, scales_mem);
} }
} else { // elementwise_mul & elementwise_div } else { // elementwise_mul & elementwise_div
funcs::BinaryOneDNNHandler<T> binary_handler(BINARY_OP, funcs::BinaryOneDNNHandler<T> binary_handler(BINARY_OP,
...@@ -176,7 +185,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -176,7 +185,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory}, {DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory}, {DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory}}; {DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scales_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scales_mem}};
binary_prim->execute(astream, args); binary_prim->execute(astream, args);
} }
...@@ -189,7 +200,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -189,7 +200,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
&dout, &dout,
broadcast_src_memory, broadcast_src_memory,
dst_memory, dst_memory,
scales, {scale},
BINARY_OP == dnnl::algorithm::binary_sub); BINARY_OP == dnnl::algorithm::binary_sub);
} else { } else {
dx->set_mem_desc(dst_memory->get_desc()); dx->set_mem_desc(dst_memory->get_desc());
...@@ -204,7 +215,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -204,7 +215,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
dst_memory = reorder_handler.AcquireDstMemory( dst_memory = reorder_handler.AcquireDstMemory(
dy, dout.mem_desc(), dev_ctx.GetPlace()); dy, dout.mem_desc(), dev_ctx.GetPlace());
AddSubNonBroadcast( AddSubNonBroadcast(
&reorder_handler, dy, reorder_src_memory, dst_memory, scales); &reorder_handler, dy, reorder_src_memory, dst_memory, scales_mem);
} }
} else { // elementwise_mul & elementwise_div } else { // elementwise_mul & elementwise_div
std::unordered_map<int, dnnl::memory> args; std::unordered_map<int, dnnl::memory> args;
...@@ -273,7 +284,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -273,7 +284,9 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
binary_prim = binary_handler.AcquireForwardPrimitive(); binary_prim = binary_handler.AcquireForwardPrimitive();
args = {{DNNL_ARG_SRC_0, *src_0_memory}, args = {{DNNL_ARG_SRC_0, *src_0_memory},
{DNNL_ARG_SRC_1, *src_1_memory}, {DNNL_ARG_SRC_1, *src_1_memory},
{DNNL_ARG_DST, *dst_dy_memory}}; {DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scales_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scales_mem}};
if (BINARY_OP == dnnl::algorithm::binary_div) if (BINARY_OP == dnnl::algorithm::binary_div)
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
...@@ -292,7 +305,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx, ...@@ -292,7 +305,7 @@ void ElementwiseGradKernel(const OneDNNContext& dev_ctx,
&dout, &dout,
broadcast_src_memory, broadcast_src_memory,
dst_memory, dst_memory,
scales, {scale},
BINARY_OP == dnnl::algorithm::binary_sub); BINARY_OP == dnnl::algorithm::binary_sub);
} else { } else {
dy->set_mem_desc(dst_memory->get_desc()); dy->set_mem_desc(dst_memory->get_desc());
......
...@@ -110,10 +110,19 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx, ...@@ -110,10 +110,19 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx,
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
const std::unordered_map<int, dnnl::memory> args = { std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC_0, *src_x_memory},
{DNNL_ARG_SRC_0, *src_x_memory}, {DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_SRC_1, *src_y_memory}, {DNNL_ARG_DST, *dst_memory}};
{DNNL_ARG_DST, *dst_memory}};
if (handler.Has_SRC_0_Scale()) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
handler.Get_SRC_0_Scale_Memory()});
}
if (handler.Has_SRC_1_Scale()) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
handler.Get_SRC_1_Scale_Memory()});
}
binary_prim->execute(astream, args); binary_prim->execute(astream, args);
astream.wait(); astream.wait();
...@@ -121,7 +130,7 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx, ...@@ -121,7 +130,7 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx,
auto out_md = dst_memory->get_desc(); auto out_md = dst_memory->get_desc();
if (handler.use_broadcasting_hack) { if (handler.use_broadcasting_hack) {
auto dims = out_md.dims(); auto dims = out_md.get_dims();
dims.insert(dims.begin(), non_const_x->dims()[0]); dims.insert(dims.begin(), non_const_x->dims()[0]);
dims[1] /= dims[0]; dims[1] /= dims[0];
out_md = out_md.reshape(dims); out_md = out_md.reshape(dims);
......
...@@ -65,7 +65,9 @@ void ExpandKernel(const Context& dev_ctx, ...@@ -65,7 +65,9 @@ void ExpandKernel(const Context& dev_ctx,
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *dst_memory_p}, {DNNL_ARG_SRC_0, *dst_memory_p},
{DNNL_ARG_SRC_1, *src_memory_p}, {DNNL_ARG_SRC_1, *src_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, handler.Get_Scale_Memory(0.0f)},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, handler.Get_Scale_Memory(1.0f)}};
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
binary_p->execute(astream, args); binary_p->execute(astream, args);
......
@@ -34,18 +34,22 @@ class FillConstantOneDNNHandler
                                          dnnl::memory::format_tag::ab);
     dnnl::primitive_attr attrs;
-    attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f});
+    attrs.set_scales_mask(DNNL_ARG_SRC_0, /* mask = */ 0);
+    src1_md_ = dnnl::memory::desc({1, sizeof(T)},
+                                  OneDNNGetDataType<uint8_t>(),
+                                  dnnl::memory::format_tag::ab);
     this->AcquireForwardPrimitiveDescriptor(
-        attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md);
+        dnnl::algorithm::binary_add, src0_md, src1_md_, src0_md, attrs);
   }
-  static const dnnl::memory::desc src1_md;
+  const dnnl::memory::desc& get_src1_md() const { return src1_md_; }
+
+ private:
+  dnnl::memory::desc src1_md_;
 };
-template <typename T>
-const dnnl::memory::desc FillConstantOneDNNHandler<T>::src1_md(
-    {1, sizeof(T)}, OneDNNGetDataType<uint8_t>(), dnnl::memory::format_tag::ab);
 }  // namespace funcs
 template <typename T, typename Context>
@@ -63,7 +67,7 @@ void FullKernel(const Context& dev_ctx,
       out, onednn_engine, dev_ctx.GetPlace());
   dnnl::memory constant_value_memory =
-      dnnl::memory(funcs::FillConstantOneDNNHandler<T>::src1_md,
+      dnnl::memory(handler.get_src1_md(),
                    onednn_engine,
                    reinterpret_cast<uint8_t*>(&fill_value));
@@ -71,10 +75,19 @@ void FullKernel(const Context& dev_ctx,
   auto fill_constant_p = handler.AcquireForwardPrimitive();
   auto& astream = OneDNNContext::tls().get_stream();
-  fill_constant_p->execute(astream,
-                           {{DNNL_ARG_SRC_0, *src0_memory_p},
-                            {DNNL_ARG_SRC_1, constant_value_memory},
-                            {DNNL_ARG_DST, *src0_memory_p}});
+  std::vector<float> zero(1, 0);
+  auto scales_md = dnnl::memory::desc(
+      {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+  auto scales = dnnl::memory(scales_md, onednn_engine, zero.data());
+  std::unordered_map<int, dnnl::memory> args;
+  args.insert({DNNL_ARG_SRC_0, *src0_memory_p});
+  args.insert({DNNL_ARG_SRC_1, constant_value_memory});
+  args.insert({DNNL_ARG_DST, *src0_memory_p});
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scales});
+  fill_constant_p->execute(astream, args);
   astream.wait();
   // src0_memory_p's md was just to allow the usage of a binary
...
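The FullKernel change relies on the same runtime-scales mechanism with a twist: the handler declares a scale mask on SRC_0, and at execution time a scale of 0.0f is supplied, so the binary_add computes 0 * src0 + fill_value and broadcasts the constant over the whole output. A hedged sketch of just the execution step, assuming the primitive was built with `set_scales_mask(DNNL_ARG_SRC_0, 0)`; all names are illustrative.

```cpp
#include "oneapi/dnnl/dnnl.hpp"
#include <unordered_map>
#include <vector>

// Sketch: fill a tensor with a constant via binary_add and a zeroed SRC_0 scale.
void fill_with_constant(dnnl::engine& eng, dnnl::stream& strm,
                        dnnl::binary& fill_prim, dnnl::memory& dst_mem,
                        dnnl::memory& constant_value_mem) {
  std::vector<float> zero(1, 0.0f);
  auto scales_md = dnnl::memory::desc({1}, dnnl::memory::data_type::f32,
                                      dnnl::memory::format_tag::x);
  dnnl::memory zero_scale(scales_md, eng, zero.data());

  std::unordered_map<int, dnnl::memory> args{
      {DNNL_ARG_SRC_0, dst_mem},             // dummy first input, scaled to zero
      {DNNL_ARG_SRC_1, constant_value_mem},  // one-element memory holding the value
      {DNNL_ARG_DST, dst_mem},
      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, zero_scale}};
  fill_prim.execute(strm, args);
  strm.wait();
}
```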
@@ -24,18 +24,21 @@ namespace phi {
 template <typename T>
 class LogSoftmaxOneDNNHandler
-    : public funcs::OneDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
+    : public funcs::OneDNNHandlerNoCachingT<T, dnnl::softmax_forward> {
  public:
   LogSoftmaxOneDNNHandler(const dnnl::engine onednn_engine,
                           Place cpu_place,
                           const DenseTensor& x,
                           const int axis)
-      : funcs::OneDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
-            onednn_engine, cpu_place) {
+      : funcs::OneDNNHandlerNoCachingT<T, dnnl::softmax_forward>(onednn_engine,
+                                                                 cpu_place) {
     const int rank = x.dims().size() != 0 ? x.dims().size() : 1;
     const int canonical_axis = funcs::CanonicalAxis(axis, rank);
-    this->AcquireForwardPrimitiveDescriptor(
-        dnnl::prop_kind::forward_inference, x.mem_desc(), canonical_axis);
+    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
+                                            dnnl::algorithm::softmax_log,
+                                            x.mem_desc(),
+                                            x.mem_desc(),
+                                            canonical_axis);
   }
 };
...
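oneDNN v3 drops the dedicated logsoftmax primitive: log-softmax is now the regular softmax primitive run with `algorithm::softmax_log`, and the primitive descriptor is built directly from the engine, propagation kind, algorithm, source/destination descriptors, and axis. A minimal sketch under those assumptions (helper name is illustrative):

```cpp
#include "oneapi/dnnl/dnnl.hpp"

// Sketch: log-softmax along `axis` with the oneDNN v3 softmax primitive.
dnnl::softmax_forward make_log_softmax(const dnnl::engine& eng,
                                       const dnnl::memory::desc& x_md,
                                       int axis) {
  auto pd = dnnl::softmax_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, dnnl::algorithm::softmax_log,
      x_md, x_md, axis);  // src and dst share the same descriptor here
  return dnnl::softmax_forward(pd);
}
```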
@@ -211,10 +211,14 @@ class MulPrimitiveFactory {
                           const std::vector<float> &scale) {
     auto mask = scale.size() > 1 ? 1 : 0;
     dnnl::primitive_attr attr;
-    attr.set_output_scales(mask, scale);
+    attr.set_scales_mask(DNNL_ARG_SRC, mask);
     auto src_mem = memory(src_desc, engine_, src_data);
     auto dst_mem = memory(dst_desc, engine_);
+    auto scales_md = dnnl::memory::desc(
+        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    auto scales_mem = dnnl::memory(
+        scales_md, engine_, funcs::to_void_cast<float>(scale.data()));
     auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attr);
@@ -222,7 +226,11 @@ class MulPrimitiveFactory {
     auto &astream = OneDNNContext::tls().get_stream();
     {
-      reorder.execute(astream, src_mem, dst_mem);
+      std::unordered_map<int, dnnl::memory> reorder_args;
+      reorder_args.insert({DNNL_ARG_SRC, src_mem});
+      reorder_args.insert({DNNL_ARG_DST, dst_mem});
+      reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_mem});
+      reorder.execute(astream, reorder_args);
       astream.wait();
     }
@@ -230,9 +238,7 @@ class MulPrimitiveFactory {
   }
   memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
-    const auto &dims = input_y.get_desc().data.dims;
-    auto ndims = input_y.get_desc().data.ndims;
-    auto y_dims = std::vector<int64_t>(dims, dims + ndims);
+    auto y_dims = input_y.get_desc().get_dims();
     auto user_y_desc =
         CreateMemDescriptor<YT>(y_dims, funcs::OneDNNMemoryFormat::oi);
@@ -272,7 +278,13 @@ class MulPrimitiveFactory {
           scale_out_data / (scale_x_data * scale_y_data[i]);
     }
     int mul_mask = is_multi_channel ? 1 : 0;
-    mul_attr.set_output_scales(mul_mask, output_shift_scale);
+    mul_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mul_mask);
+    auto scales_md = dnnl::memory::desc(
+        {count}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
+    scales_mem_ = dnnl::memory(scales_md, engine_);
+    auto mem_buf = scales_mem_.get_data_handle();
+    memcpy(mem_buf, output_shift_scale.data(), count * sizeof(float));
     return mul_attr;
   }
@@ -286,19 +298,17 @@ class MulPrimitiveFactory {
     const auto y_desc = y_memory.get_desc();
     inner_product_forward::primitive_desc mul_prim_desc;
-    const auto &mul_desc = inner_product_forward::desc(
-        prop_kind::forward, x_desc, y_desc, dst_desc);
     if (is_int8_) {
       bool force_fp32_output =
           dev_ctx.HasDnnAttr("force_fp32_output")
               ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
               : false;
       auto mul_attr = CreateMulAttr(dev_ctx, force_fp32_output);
-      mul_prim_desc =
-          inner_product_forward::primitive_desc(mul_desc, mul_attr, engine_);
+      mul_prim_desc = inner_product_forward::primitive_desc(
+          engine_, prop_kind::forward, x_desc, y_desc, dst_desc, mul_attr);
     } else {
-      mul_prim_desc = inner_product_forward::primitive_desc(mul_desc, engine_);
+      mul_prim_desc = inner_product_forward::primitive_desc(
+          engine_, prop_kind::forward, x_desc, y_desc, dst_desc);
     }
     output_ = CreateDstMemory(mul_prim_desc, dev_ctx, output);
@@ -308,10 +318,12 @@ class MulPrimitiveFactory {
   void Execute() {
     auto &astream = OneDNNContext::tls().get_stream();
     (*mul_).execute(astream,
                     {{DNNL_ARG_SRC, *x_input_},
                      {DNNL_ARG_WEIGHTS, *y_input_},
-                     {DNNL_ARG_DST, *output_}});
+                     {DNNL_ARG_DST, *output_},
+                     {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales_mem_}});
     astream.wait();
   }
@@ -428,6 +440,7 @@ class MulPrimitiveFactory {
   paddle::optional<memory> output_;
   paddle::optional<inner_product_forward> mul_;
   static constexpr bool is_int8_ = funcs::is_int8<XT>();
+  dnnl::memory scales_mem_;
 };
 /* OT: output data type */
@@ -511,9 +524,12 @@ void MatmulWithFlattenKernelINT8(const Context &dev_ctx,
     out->Resize(out_dims);
   }
-  auto in_md = memory::desc(*dnnl_primitive_desc_query_md(
-      mul.get_primitive_desc(), dnnl_query_dst_md, 0));
-  out->set_mem_desc(in_md.reshape(vectorize<int64_t>(out->dims())));
+  auto in_md = dnnl_primitive_desc_query_md(
+      mul.get_primitive_desc(), dnnl_query_dst_md, 0);
+  dnnl_memory_desc_t cloned_in_md = nullptr;
+  dnnl_memory_desc_clone(&cloned_in_md, in_md);
+  out->set_mem_desc(
+      memory::desc(cloned_in_md).reshape(vectorize<int64_t>(out->dims())));
 }
 template <typename T, typename Context>
...
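Two oneDNN v3 changes are visible in the mul/matmul_with_flatten file: operation descriptors are gone, so `inner_product_forward::primitive_desc` is built straight from the engine and memory descriptors, and the old `set_output_scales` is replaced by per-argument scales attached to the weights and delivered at execution time. A compact hedged sketch of that pattern (quantization and buffer setup omitted; names illustrative):

```cpp
#include "oneapi/dnnl/dnnl.hpp"
#include <cstring>
#include <unordered_map>
#include <vector>

// Sketch: int8 inner product with per-output-channel weight scales (oneDNN v3).
void run_int8_inner_product(dnnl::engine& eng, dnnl::stream& strm,
                            dnnl::memory& src, dnnl::memory& weights,
                            dnnl::memory& dst,
                            const std::vector<float>& wei_scales) {
  dnnl::primitive_attr attr;
  int mask = wei_scales.size() > 1 ? 1 : 0;  // bit 0 => per output channel
  attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);

  auto pd = dnnl::inner_product_forward::primitive_desc(
      eng, dnnl::prop_kind::forward_inference, src.get_desc(),
      weights.get_desc(), dst.get_desc(), attr);
  auto ip = dnnl::inner_product_forward(pd);

  auto scales_md = dnnl::memory::desc(
      {static_cast<dnnl::memory::dim>(wei_scales.size())},
      dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
  dnnl::memory scales_mem(scales_md, eng);
  std::memcpy(scales_mem.get_data_handle(), wei_scales.data(),
              wei_scales.size() * sizeof(float));

  ip.execute(strm, {{DNNL_ARG_SRC, src},
                    {DNNL_ARG_WEIGHTS, weights},
                    {DNNL_ARG_DST, dst},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales_mem}});
  strm.wait();
}
```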
@@ -143,7 +143,11 @@ void ReduceGradKernel(const Context& dev_ctx,
   const std::unordered_map<int, dnnl::memory> args = {
       {DNNL_ARG_SRC_0, *dst_memory_p},
       {DNNL_ARG_SRC_1, *src_memory_p},
-      {DNNL_ARG_DST, *dst_memory_p}};
+      {DNNL_ARG_DST, *dst_memory_p},
+      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
+       handler.Get_Scale_Memory(scale_x)},
+      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
+       handler.Get_Scale_Memory(scale_y)}};
   auto& astream = OneDNNContext::tls().get_stream();
   binary_prim->execute(astream, args);
...
@@ -120,7 +120,7 @@ void ExecuteReshape(const Context& dev_ctx,
                     const DDim& x_dims,
                     DenseTensor* out) {
   auto out_dims = ValidateShape(shape.GetData(), x_dims);
-  auto x_vec_dims = x.mem_desc().dims();
+  auto x_vec_dims = x.mem_desc().get_dims();
   funcs::ReorderOneDNNHandler reorder_handler(
       x_vec_dims,
...
@@ -23,12 +23,12 @@ const std::vector<int64_t> get_slice_strides(
     const std::vector<int64_t>& out_vec_dims,
     const dnnl::memory::desc& full_md,
     int axis) {
-  auto strides = full_md.data.format_desc.blocking.strides;
-  auto ndims = full_md.data.ndims;
-  auto full_dims = full_md.data.dims;
+  auto strides = full_md.get_strides();
+  auto ndims = full_md.get_dims().size();
+  auto full_dims = full_md.get_dims();
   auto splitted_stride = strides[axis];
   std::vector<int64_t> slice_strides(ndims, splitted_stride);
-  for (int16_t i = 0; i < ndims; ++i) {
+  for (size_t i = 0; i < ndims; ++i) {
     slice_strides[i] = strides[i] > splitted_stride
                            ? (strides[i] / full_dims[axis]) * out_vec_dims[axis]
                            : strides[i];
...
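A recurring mechanical change across these kernels is that oneDNN v3 hides the underlying C struct: fields previously read through `md.data` (dims, ndims, data type, blocking strides) are now obtained through getters on `dnnl::memory::desc`. A small sketch of the new query style (not the patched function itself):

```cpp
#include "oneapi/dnnl/dnnl.hpp"

// Sketch: querying a dnnl::memory::desc in oneDNN v3.
void inspect_memory_desc(const dnnl::memory::desc& md) {
  auto dims = md.get_dims();        // was: md.data.dims / md.data.ndims
  auto dtype = md.get_data_type();  // was: md.data.data_type
  auto strides = md.get_strides();  // was: md.data.format_desc.blocking.strides
  (void)dims; (void)dtype; (void)strides;
}
```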
@@ -73,16 +73,6 @@ class StackOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::concat> {
     this->AcquireForwardPrimitiveDescriptor(dst_md, stack_axis, srcs_md);
   }
-  // concat oneDNN prim is not having .desc attribute so we cannot use default
-  // AcquireForwardPrimitiveDescriptor
-  void AcquireForwardPrimitiveDescriptor(
-      const memory::desc& dst_md,
-      const int stack_axis,
-      const std::vector<memory::desc>& srcs_md) {
-    this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
-        dst_md, stack_axis, srcs_md, this->engine_));
-  }
   std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor& input,
                                                  int i) {
     const T* input_data = input.data<T>();
...
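The custom AcquireForwardPrimitiveDescriptor removed above was only needed because concat had no operation descriptor in oneDNN v2; in v3 the concat primitive descriptor takes the engine as its first argument and so fits the generic handler path. A sketch of the direct v3 construction, assuming `srcs_md` holds the per-input descriptors (helper name is illustrative):

```cpp
#include "oneapi/dnnl/dnnl.hpp"
#include <vector>

// Sketch: building a concat primitive in oneDNN v3.
dnnl::concat make_concat(const dnnl::engine& eng,
                         const dnnl::memory::desc& dst_md,
                         int concat_axis,
                         const std::vector<dnnl::memory::desc>& srcs_md) {
  auto pd = dnnl::concat::primitive_desc(eng, dst_md, concat_axis, srcs_md);
  return dnnl::concat(pd);
}
```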
@@ -73,8 +73,8 @@ void TransposeKernel(const Context& dev_ctx,
       x.mem_desc(), funcs::to_void_cast(x.data<T>()));
   auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis);
-  auto dst_md =
-      dnnl::memory::desc(x_vec_dims, x.mem_desc().data_type(), fake_strides);
+  auto dst_md = dnnl::memory::desc(
+      x_vec_dims, x.mem_desc().get_data_type(), fake_strides);
   auto reorder_dst_memory_p =
       reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
   auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
...
@@ -47,8 +47,6 @@ env_dict={
     'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@',
     'WITH_MKLDNN':'@WITH_MKLDNN@',
     'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@',
-    'MKLDNN_SHARED_LIB_1':'@MKLDNN_SHARED_LIB_1@',
-    'MKLDNN_SHARED_LIB_2':'@MKLDNN_SHARED_LIB_2@',
     'MKLDNN_INSTALL_DIR':'@MKLDNN_INSTALL_DIR@',
     'WITH_ONNXRUNTIME':'@WITH_ONNXRUNTIME@',
     'ONNXRUNTIME_SHARED_LIB':'@ONNXRUNTIME_SHARED_LIB@',
...
@@ -654,9 +654,7 @@ if '${WITH_MKLDNN}' == 'ON':
            raise Exception("patch libdnnl.so failed, command: %s" % command)
        shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
        if os.name != 'nt':
-           shutil.copy('${MKLDNN_SHARED_LIB_1}', libs_path)
-           shutil.copy('${MKLDNN_SHARED_LIB_2}', libs_path)
-           package_data['paddle.libs']+=['libmkldnn.so.0', 'libdnnl.so.1', 'libdnnl.so.2']
+           package_data['paddle.libs']+=['libdnnl.so.3']
        else:
            package_data['paddle.libs']+=['mkldnn.dll']
...
@@ -137,7 +137,7 @@ if '${WITH_MKL}' == 'ON':
        cinnlibs.append('${MKLML_IOMP_LIB}')
    if '${WITH_MKLDNN}' == 'ON':
-       cinnlibs.append('${MKLDNN_SHARED_LIB_2}')
+       cinnlibs.append('${MKLDNN_SHARED_LIB}')
    if '${WITH_GPU}' == 'ON':
        cinnlibs.append('${CMAKE_BINARY_DIR}/dist/cinn/include/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh')
...
@@ -1088,13 +1088,7 @@ def get_package_data_and_package_dir():
            )
        shutil.copy(env_dict.get("MKLDNN_SHARED_LIB"), libs_path)
        if os.name != 'nt':
-           shutil.copy(env_dict.get("MKLDNN_SHARED_LIB_1"), libs_path)
-           shutil.copy(env_dict.get("MKLDNN_SHARED_LIB_2"), libs_path)
-           package_data['paddle.libs'] += [
-               'libmkldnn.so.0',
-               'libdnnl.so.1',
-               'libdnnl.so.2',
-           ]
+           package_data['paddle.libs'] += ['libdnnl.so.3']
        else:
            package_data['paddle.libs'] += ['mkldnn.dll']
...
@@ -183,7 +183,7 @@ if(WITH_MKL)
     if(WIN32)
       set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib)
     else()
-      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+      set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libdnnl.so.3)
     endif()
   endif()
 else()
...
@@ -41,6 +41,7 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.mkldnn_data_type = "int8"
         self.weighttype = np.float32
         self.use_mkldnn = True
+        self.init_weight_quantization_type()
         self.init_group()
         self.init_dilation()
         self.init_test_case()
@@ -181,8 +182,9 @@ class TestConv2DInt8Op(TestConv2DOp):
     def test_check_output(self):
         # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        # the atol for integer tests should be 1
         self.check_output_with_place(
-            core.CPUPlace(), atol=0, check_dygraph=False
+            core.CPUPlace(), atol=1, check_dygraph=False
         )
     def test_check_grad(self):
@@ -202,9 +204,16 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.filter_size = [2, f_c, 3, 3]
         self.scale_in = 0.95
         self.scale_out = 0.5
-        self.scale_weights = [10.0]
+        self.scale_weights = (
+            [10.0] * self.filter_size[0]
+            if self.per_channel_quantize_weight
+            else [10.0]
+        )
         self.scale_in_eltwise = 0.6
+    def init_weight_quantization_type(self):
+        self.per_channel_quantize_weight = False
     def init_data_type(self):
         self.srctype = np.uint8
         self.dsttype = np.int8
@@ -239,15 +248,15 @@ class TestConv2D(TestConv2DInt8Op):
 class TestWithHardSwish(TestConv2D):
     def init_fuse_activation(self):
         self.fuse_activation = "hard_swish"
-        self.fuse_alpha = 0
-        self.fuse_beta = 0
+        self.fuse_alpha = 1.0 / 6.0
+        self.fuse_beta = 1.0 / 2.0
 class TestWithRelu6(TestConv2D):
     def init_fuse_activation(self):
         self.fuse_activation = "relu6"
-        self.fuse_alpha = 6
-        self.fuse_beta = 0
+        self.fuse_alpha = 0
+        self.fuse_beta = 6
 class TestWithSwish(TestConv2D):
@@ -350,6 +359,34 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
     self.fuse_residual = fuse_residual
+class TestDepthwiseConv2d(TestConv2D):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [1, 32, 112, 112]
+        self.input_residual_size = [1, 32, 112, 112]
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [32, f_c, 3, 3]
+        self.scale_in = 0.95
+        self.scale_out = 0.5
+        self.scale_weights = (
+            [10.0] * self.filter_size[0]
+            if self.per_channel_quantize_weight
+            else [10.0]
+        )
+        self.scale_in_eltwise = 0.8
+    def init_group(self):
+        self.groups = 32
+    def init_weight_quantization_type(self):
+        self.per_channel_quantize_weight = True
+    def init_fuse_residual(self):
+        self.fuse_residual = False
 def create_test_int8_class(parent):
     # --------------------test conv2d s8 in and u8 out--------------------
     class TestS8U8Case(parent):
...
@@ -92,7 +92,7 @@ class TestConv2DMKLDNNOp(TestConv2DOp):
             output = np.maximum(output, 0).astype(self.dsttype)
         if self.fuse_activation == "relu6":
-            output = np.minimum(np.maximum(output, 0), self.fuse_alpha).astype(
+            output = np.minimum(np.maximum(output, 0), self.fuse_beta).astype(
                 self.dsttype
             )
         if (
@@ -120,7 +120,7 @@ class TestWithbreluFusion(TestConv2DMKLDNNOp):
     def init_test_case(self):
         TestConv2DMKLDNNOp.init_test_case(self)
         self.fuse_activation = "relu6"
-        self.fuse_alpha = 6.0
+        self.fuse_beta = 6.0
         self.dsttype = np.float32
...
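The test updates above follow the oneDNN v3 eltwise conventions the fused conv kernels now use: hard_swish takes alpha = 1/6 and beta = 1/2, and relu6 is expressed as a clip with alpha = 0 (lower bound) and beta = 6 (upper bound), which is why the Python references now cap with fuse_beta. A hedged C++ sketch of how such post-ops are attached; the exact algorithm Paddle's handler maps relu6 to is inferred from the alpha/beta values and not shown in this diff.

```cpp
#include "oneapi/dnnl/dnnl.hpp"

// Sketch: eltwise post-ops with oneDNN v3 alpha/beta conventions.
dnnl::primitive_attr make_fused_activation_attr() {
  dnnl::post_ops ops;
  ops.append_eltwise(dnnl::algorithm::eltwise_hardswish, 1.0f / 6.0f, 0.5f);
  ops.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0.0f, 6.0f);  // relu6 as clip
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;
}
```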
@@ -24,15 +24,21 @@ class TestFCINT8OneDNNOp(OpTest):
         self.op_type = "fc"
         self._cpu_only = True
         self.configure()
+        self.set_shape()
         self.generate_data()
         self.set_inputs()
+        y_scales_size = (
+            self.bias_shape if self.per_channel_quantize_weight else 1
+        )
         self.attrs = {
             'use_mkldnn': True,
             'Scale_in': self.x_scale,
-            'Scale_weights': [self.y_scale],
+            'Scale_weights': [self.y_scale] * y_scales_size,
             'Scale_out': self.out_scale,
             'force_fp32_output': self.force_fp32_output,
+            'in_num_col_dims': self.in_num_col_dims,
         }
         if self.force_fp32_output:
@@ -45,6 +51,13 @@ class TestFCINT8OneDNNOp(OpTest):
     def configure(self):
         self.use_bias = True
         self.force_fp32_output = False
+        self.in_num_col_dims = 1
+        self.per_channel_quantize_weight = False
+    def set_shape(self):
+        self.input_shape = (10, 5)
+        self.weight_shape = (5, 10)
+        self.bias_shape = 10
     def set_inputs(self):
         self.inputs = {'Input': self.x, 'W': self.y_float, 'Bias': self.bias}
@@ -55,15 +68,26 @@ class TestFCINT8OneDNNOp(OpTest):
         return scale, quantized
     def generate_data(self):
-        self.x_float = np.random.random((10, 5)).astype("float32") * 10
+        self.x_float = np.random.random(self.input_shape).astype("float32") * 10
         self.x_scale, self.x = self.quantize(self.x_float)
-        self.y_float = np.random.random((5, 10)).astype("float32") * 10
+        self.y_float = (
+            np.random.random(self.weight_shape).astype("float32") * 10
+        )
         self.y_scale, self.y = self.quantize(self.y_float)
-        self.out_float = np.dot(self.x_float, self.y_float)
+        flatten_shape = [1, 1]
+        for i in range(len(self.input_shape)):
+            if i < self.in_num_col_dims:
+                flatten_shape[0] *= self.input_shape[i]
+            else:
+                flatten_shape[1] *= self.input_shape[i]
+        self.out_float = np.dot(
+            self.x_float.reshape(flatten_shape), self.y_float
+        )
         if self.use_bias:
-            self.bias = np.random.random(10).astype("float32") * 10
+            self.bias = np.random.random(self.bias_shape).astype("float32") * 10
             self.out_float += self.bias
         self.out_scale, self.out = self.quantize(self.out_float)
@@ -77,6 +101,8 @@ class TestFCINT8NoBiasOneDNNOp(TestFCINT8OneDNNOp):
     def configure(self):
         self.use_bias = False
         self.force_fp32_output = False
+        self.in_num_col_dims = 1
+        self.per_channel_quantize_weight = False
     def set_inputs(self):
         self.inputs = {
@@ -89,6 +115,21 @@ class TestFCINT8ForceFP32OutputOneDNNOp(TestFCINT8NoBiasOneDNNOp):
     def configure(self):
         self.use_bias = False
         self.force_fp32_output = True
+        self.in_num_col_dims = 1
+        self.per_channel_quantize_weight = False
+class TestFCINT8ForceFP32OutputPerChannelWeightOneDNNOp(TestFCINT8OneDNNOp):
+    def configure(self):
+        self.use_bias = True
+        self.force_fp32_output = True
+        self.in_num_col_dims = 1
+        self.per_channel_quantize_weight = True
+    def set_shape(self):
+        self.input_shape = (1, 8, 1, 1)
+        self.weight_shape = (8, 10)
+        self.bias_shape = 10
 if __name__ == "__main__":
...
-Subproject commit 2089770c4818be8933c5e9d1dd3cbaeba1457667
+Subproject commit 64f6bcbcbab628e96f33a62c3e975f8535a7bde4