未验证 提交 4dc4d5fc 编写于 作者: H HongyuJia 提交者: GitHub

[MKLDNN] Delete mkldnn hard code of fc (#47138)

* remove fc mkldnn hardcode

* remove useless enum of kFCMKLDNN

* fix macro error

* update operators.cmake
上级 420d4bc7
...@@ -511,13 +511,6 @@ function(op_library TARGET) ...@@ -511,13 +511,6 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator # Append first implemented MKLDNN activation operator
if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, U8);\n")
else() else()
foreach(mkldnn_src ${mkldnn_cc_srcs}) foreach(mkldnn_src ${mkldnn_cc_srcs})
set(op_name "") set(op_name "")
......
...@@ -128,18 +128,6 @@ class FCOp : public framework::OperatorWithKernel { ...@@ -128,18 +128,6 @@ class FCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input"); OperatorWithKernel::IndicateVarDataType(ctx, "Input");
if (ctx.Attr<bool>("use_mkldnn")) {
using framework::proto::VarType;
int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kFCMKLDNNINT8
: kFCMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -22,8 +22,6 @@ limitations under the License. */ ...@@ -22,8 +22,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
enum { kFCMKLDNNFP32 = 1, kFCMKLDNNINT8 = 2 };
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
inline void FCOutputSize(const framework::DDim& in_dims, inline void FCOutputSize(const framework::DDim& in_dims,
......
...@@ -320,27 +320,38 @@ class FCMKLDNNHandler ...@@ -320,27 +320,38 @@ class FCMKLDNNHandler
} // namespace operators } // namespace operators
}; // namespace paddle }; // namespace paddle
template <typename T_in, typename T_w> #define IF_CHANGE_FC_TW_TYPENAME(condition, ...) \
if (condition) { \
using T_w = int8_t; \
__VA_ARGS__(); \
} else { \
using T_w = T_in; \
__VA_ARGS__(); \
}
template <typename T_in>
class FCMKLDNNKernel : public framework::OpKernel<T_in> { class FCMKLDNNKernel : public framework::OpKernel<T_in> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu"; bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
if (force_fp32_output) { IF_CHANGE_FC_TW_TYPENAME((std::is_same<T_in, uint8_t>::value), ([&] {
this->RunKernel<float>(ctx); if (force_fp32_output) {
} else if (IsInt8<T_in>()) { this->RunKernel<float, T_w>(ctx);
if (fuse_relu) { } else if (IsInt8<T_in>()) {
this->RunKernel<uint8_t>(ctx); if (fuse_relu) {
} else { this->RunKernel<uint8_t, T_w>(ctx);
this->RunKernel<int8_t>(ctx); } else {
} this->RunKernel<int8_t, T_w>(ctx);
} else { }
this->RunKernel<T_in>(ctx); } else {
} this->RunKernel<T_in, T_w>(ctx);
}
}));
} }
template <typename T_out = T_w> template <typename T_out, typename T_w>
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
...@@ -422,32 +433,11 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -422,32 +433,11 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
// data type implies their destination data type. (What's eventually going to // data type implies their destination data type. (What's eventually going to
// be used during computations of kernel). // be used during computations of kernel).
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN, REGISTER_OP_KERNEL(fc,
::paddle::platform::CPUPlace, MKLDNN,
FP32, ::paddle::platform::CPUPlace,
ops::kFCMKLDNNFP32, ops::FCMKLDNNKernel<float>,
ops::FCMKLDNNKernel<float, float>); ops::FCMKLDNNKernel<paddle::platform::bfloat16>,
ops::FCMKLDNNKernel<uint8_t>,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( ops::FCMKLDNNKernel<int8_t>);
fc,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kFCMKLDNNFP32,
ops::FCMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNKernel<uint8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNKernel<int8_t, int8_t>);
...@@ -70,7 +70,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = { ...@@ -70,7 +70,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we // NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analysis and solve them one-by-one. // need to analysis and solve them one-by-one.
"prior_box", "prior_box",
"fc",
"mul", "mul",
"mul_grad"}; "mul_grad"};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册