diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 99258ac169ba4cfc8c26fc97f1b3abb1243654da..07a20bfae54745da8f4bd5eed6d036dc33dc3a72 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -32,7 +32,7 @@ using mkldnn::stream; using platform::GetMKLDNNFormat; //using MKLDNNDataType = mkldnn::memory::data_type; -template +template class DeQuantOpKernel : public framework::OpKernel { public: @@ -83,13 +83,17 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); + framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); } void DeQuantOpMaker::Make() { @@ -108,6 +112,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel); - +REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index fb9285e61730cca8fbb40ab70a9305bd69e8d121..a18c6f74137bee726fe04cd069d4a293bedfaa60 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -30,7 +30,7 @@ using framework::DataLayout; using mkldnn::stream; using platform::GetMKLDNNFormat; -template +template class QuantOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -76,14 +76,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); - //ctx.device_context()); } @@ -103,10 +106,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel); - -//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel); - +REGISTER_OP_KERNEL(quantize, MKLDNN, ::paddle::platform::CPUPlace, ops::QuantOpKernel); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index 86f5535f90a225160ebfdaf478a3cd7b5292ab4c..6decbded6b028ee9d788b6bfacee12528ee447ce 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -31,7 +31,7 @@ using framework::DataLayout; using mkldnn::stream; using platform::GetMKLDNNFormat; -template +template class ReQuantOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -84,13 +84,17 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::Execut framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); + framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); } void ReQuantOpMaker::Make() { @@ -109,5 +113,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel); - +REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel);