未验证 提交 4dc4d5fc 编写于 作者: H HongyuJia 提交者: GitHub

[MKLDNN] Delete mkldnn hard code of fc (#47138)

* remove fc mkldnn hardcode

* remove useless enum of kFCMKLDNN

* fix macro error

* update operators.cmake
上级 420d4bc7
......@@ -511,13 +511,6 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
else()
foreach(mkldnn_src ${mkldnn_cc_srcs})
set(op_name "")
......
......@@ -128,18 +128,6 @@ class FCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input");
if (ctx.Attr<bool>("use_mkldnn")) {
using framework::proto::VarType;
int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kFCMKLDNNINT8
: kFCMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -22,8 +22,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
enum { kFCMKLDNNFP32 = 1, kFCMKLDNNINT8 = 2 };
using Tensor = phi::DenseTensor;
inline void FCOutputSize(const framework::DDim& in_dims,
......
......@@ -320,27 +320,38 @@ class FCMKLDNNHandler
} // namespace operators
}; // namespace paddle
template <typename T_in, typename T_w>
#define IF_CHANGE_FC_TW_TYPENAME(condition, ...) \
if (condition) { \
using T_w = int8_t; \
__VA_ARGS__(); \
} else { \
using T_w = T_in; \
__VA_ARGS__(); \
}
template <typename T_in>
class FCMKLDNNKernel : public framework::OpKernel<T_in> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
if (force_fp32_output) {
this->RunKernel<float>(ctx);
} else if (IsInt8<T_in>()) {
if (fuse_relu) {
this->RunKernel<uint8_t>(ctx);
} else {
this->RunKernel<int8_t>(ctx);
}
} else {
this->RunKernel<T_in>(ctx);
}
IF_CHANGE_FC_TW_TYPENAME((std::is_same<T_in, uint8_t>::value), ([&] {
if (force_fp32_output) {
this->RunKernel<float, T_w>(ctx);
} else if (IsInt8<T_in>()) {
if (fuse_relu) {
this->RunKernel<uint8_t, T_w>(ctx);
} else {
this->RunKernel<int8_t, T_w>(ctx);
}
} else {
this->RunKernel<T_in, T_w>(ctx);
}
}));
}
template <typename T_out = T_w>
template <typename T_out, typename T_w>
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
......@@ -422,32 +433,11 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
// data type implies their destination data type. (What's eventually going to
// be used during computations of kernel).
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kFCMKLDNNFP32,
ops::FCMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
fc,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kFCMKLDNNFP32,
ops::FCMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNKernel<uint8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNKernel<int8_t, int8_t>);
REGISTER_OP_KERNEL(fc,
MKLDNN,
::paddle::platform::CPUPlace,
ops::FCMKLDNNKernel<float>,
ops::FCMKLDNNKernel<paddle::platform::bfloat16>,
ops::FCMKLDNNKernel<uint8_t>,
ops::FCMKLDNNKernel<int8_t>);
......@@ -70,7 +70,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analysis and solve them one-by-one.
"prior_box",
"fc",
"mul",
"mul_grad"};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册