提交 cb3f4264 编写于 作者: X xiaolil1

fix quantize op register bug

上级 f80ab8fe
......@@ -32,7 +32,7 @@ using mkldnn::stream;
using platform::GetMKLDNNFormat;
//using MKLDNNDataType = mkldnn::memory::data_type;
template <typename DeviceContext, typename T>
template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
public:
......@@ -83,13 +83,17 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
}
void DeQuantOpMaker::Make() {
......@@ -108,6 +112,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<float>);
......@@ -30,7 +30,7 @@ using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T>
template <typename T>
class QuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -76,14 +76,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
//ctx.device_context());
}
......@@ -103,10 +106,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>);
//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_KERNEL(quantize, MKLDNN, ::paddle::platform::CPUPlace, ops::QuantOpKernel<float>);
......
......@@ -31,7 +31,7 @@ using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename DeviceContext, typename T>
template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -84,13 +84,17 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::Execut
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
}
void ReQuantOpMaker::Make() {
......@@ -109,5 +113,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册