提交 cb3f4264 编写于 作者: X xiaolil1

fix quantize op register bug

上级 f80ab8fe
...@@ -32,7 +32,7 @@ using mkldnn::stream; ...@@ -32,7 +32,7 @@ using mkldnn::stream;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
//using MKLDNNDataType = mkldnn::memory::data_type; //using MKLDNNDataType = mkldnn::memory::data_type;
template <typename DeviceContext, typename T> template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> { class DeQuantOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -83,13 +83,17 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut ...@@ -83,13 +83,17 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_); framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
} }
void DeQuantOpMaker::Make() { void DeQuantOpMaker::Make() {
...@@ -108,6 +112,5 @@ namespace ops = paddle::operators; ...@@ -108,6 +112,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel<paddle::platform::CPUDeviceContext, float>); REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<float>);
...@@ -30,7 +30,7 @@ using framework::DataLayout; ...@@ -30,7 +30,7 @@ using framework::DataLayout;
using mkldnn::stream; using mkldnn::stream;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T> template <typename T>
class QuantOpKernel : public framework::OpKernel<T> { class QuantOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -76,14 +76,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio ...@@ -76,14 +76,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_); framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
//ctx.device_context());
} }
...@@ -103,10 +106,7 @@ namespace ops = paddle::operators; ...@@ -103,10 +106,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>); REGISTER_OP_KERNEL(quantize, MKLDNN, ::paddle::platform::CPUPlace, ops::QuantOpKernel<float>);
//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -31,7 +31,7 @@ using framework::DataLayout; ...@@ -31,7 +31,7 @@ using framework::DataLayout;
using mkldnn::stream; using mkldnn::stream;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T> template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> { class ReQuantOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -84,13 +84,17 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::Execut ...@@ -84,13 +84,17 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::Execut
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_); framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
} }
void ReQuantOpMaker::Make() { void ReQuantOpMaker::Make() {
...@@ -109,5 +113,4 @@ namespace ops = paddle::operators; ...@@ -109,5 +113,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel<paddle::platform::CPUDeviceContext, float>); REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册