From cb3f426433f608fe0848ba802954afa24d78a2be Mon Sep 17 00:00:00 2001 From: xiaolil1 Date: Fri, 12 Oct 2018 12:55:29 +0800 Subject: [PATCH] fix quantize op register bug --- paddle/fluid/operators/dequantize_op.cc | 11 +++++++---- paddle/fluid/operators/quantize_op.cc | 12 ++++++------ paddle/fluid/operators/requantize_op.cc | 11 +++++++---- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 99258ac169b..07a20bfae54 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -32,7 +32,7 @@ using mkldnn::stream; using platform::GetMKLDNNFormat; //using MKLDNNDataType = mkldnn::memory::data_type; -template +template class DeQuantOpKernel : public framework::OpKernel { public: @@ -83,13 +83,17 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); + framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); } void DeQuantOpMaker::Make() { @@ -108,6 +112,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel); - +REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index fb9285e6173..a18c6f74137 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -30,7 +30,7 @@ using framework::DataLayout; using mkldnn::stream; using platform::GetMKLDNNFormat; -template +template class QuantOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -76,14 +76,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); - //ctx.device_context()); } @@ -103,10 +106,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel); - -//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel); - +REGISTER_OP_KERNEL(quantize, MKLDNN, ::paddle::platform::CPUPlace, ops::QuantOpKernel); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index 86f5535f90a..6decbded6b0 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -31,7 +31,7 @@ using framework::DataLayout; using mkldnn::stream; using platform::GetMKLDNNFormat; -template +template class ReQuantOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -84,13 +84,17 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::Execut framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + +#ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } +#endif + return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); + framework::ToDataType(ctx.Input("Input")->type()),ctx.GetPlace(),layout_, library_); } void ReQuantOpMaker::Make() { @@ -109,5 +113,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel); - +REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel); -- GitLab