提交 350a58b0 编写于 作者: X xiaolil1

fix quant op initialization bug

上级 01431825
...@@ -103,18 +103,26 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -103,18 +103,26 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
} }
}; };
framework::OpKernelType DeQuantOp::GetExpectedKernelType( framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
ctx.device_context());
} }
void DeQuantOpMaker::Make() { void DeQuantOpMaker::Make() {
AddInput("Input","input"); AddInput("Input","input");
AddInput("Scale","scale..."); AddInput("Scale","scale...");
AddOutput("Output","output"); AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from INT8 to FP32
)DOC");
} }
} // namespace operators } // namespace operators
......
...@@ -73,11 +73,18 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -73,11 +73,18 @@ class QuantOpKernel : public framework::OpKernel<T> {
} }
}; };
framework::OpKernelType QuantOp::GetExpectedKernelType( framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
ctx.device_context()); //ctx.device_context());
} }
...@@ -85,6 +92,9 @@ void QuantOpMaker::Make() { ...@@ -85,6 +92,9 @@ void QuantOpMaker::Make() {
AddInput("Input","input"); AddInput("Input","input");
AddInput("Scale","scale..."); AddInput("Scale","scale...");
AddOutput("Output","output"); AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from FP32 to INT8
)DOC");
} }
} // namespace operators } // namespace operators
......
...@@ -80,18 +80,26 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -80,18 +80,26 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
} }
}; };
framework::OpKernelType ReQuantOp::GetExpectedKernelType( framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
ctx.device_context());
} }
void ReQuantOpMaker::Make() { void ReQuantOpMaker::Make() {
AddInput("Input","input"); AddInput("Input","input");
AddInput("Scale","scale..."); AddInput("Scale","scale...");
AddOutput("Output","output"); AddOutput("Output","output");
AddComment(R"DOC(
This op will requantize data from INT8 to INT8
)DOC");
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册