Commit 350a58b0 authored by X xiaolil1

fix quant op initialization bug

Parent 01431825
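
All three ops below (quantize, dequantize, requantize) receive the same fix: GetExpectedKernelType used to build the kernel type from ctx.device_context() alone, so layout and library stayed at their plain defaults and the MKL-DNN INT8 kernels were never selected. The new code chooses the MKL-DNN library and layout when platform::CanMKLDNNBeUsed(ctx) holds and passes ctx.GetPlace(), the layout, and the library explicitly. A minimal sketch of that shared pattern, factored into a free function purely for illustration (GetQuantKernelType is a hypothetical name, not part of this commit; it assumes the same includes as the surrounding op files):

// Illustration only: the kernel-type selection pattern this commit applies
// to QuantOp, DeQuantOp and ReQuantOp.
static framework::OpKernelType GetQuantKernelType(
    const framework::ExecutionContext& ctx) {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout =
      framework::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  if (library == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    // Prefer the MKL-DNN kernel and its layout when MKL-DNN is available.
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
  // Pass place, layout and library explicitly; the old code passed only
  // ctx.device_context(), which dropped the layout/library information.
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
      ctx.GetPlace(), layout, library);
}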
@@ -103,18 +103,26 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
}
};
framework::OpKernelType DeQuantOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library_{framework::LibraryType::kPlain};
  std::string data_format = ctx.Attr<std::string>("data_format");
  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
  if (library_ == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
    layout_ = framework::DataLayout::kMKLDNN;
  }
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
      ctx.GetPlace(), layout_, library_);
}
void DeQuantOpMaker::Make() {
  AddInput("Input", "input");
  AddInput("Scale", "scale...");
  AddOutput("Output", "output");
  AddComment(R"DOC(
This op will dequantize data from INT8 to FP32
)DOC");
}
} // namespace operators
@@ -73,11 +73,18 @@ class QuantOpKernel : public framework::OpKernel<T> {
}
};
framework::OpKernelType QuantOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library_{framework::LibraryType::kPlain};
  std::string data_format = ctx.Attr<std::string>("data_format");
  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
  if (library_ == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
    layout_ = framework::DataLayout::kMKLDNN;
  }
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
      ctx.GetPlace(), layout_, library_);
}
@@ -85,6 +92,9 @@ void QuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from FP32 to INT8
)DOC");
}
} // namespace operators
@@ -80,18 +80,26 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
}
};
framework::OpKernelType ReQuantOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library_{framework::LibraryType::kPlain};
  std::string data_format = ctx.Attr<std::string>("data_format");
  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
  if (library_ == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library_ = framework::LibraryType::kMKLDNN;
    layout_ = framework::DataLayout::kMKLDNN;
  }
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
      ctx.GetPlace(), layout_, library_);
}
void ReQuantOpMaker::Make() {
  AddInput("Input", "input");
  AddInput("Scale", "scale...");
  AddOutput("Output", "output");
  AddComment(R"DOC(
This op will requantize data from INT8 to INT8
)DOC");
}
} // namespace operators
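
For context only (not part of this commit): assuming the usual symmetric, per-tensor, scale-based INT8 scheme that MKL-DNN reorders implement, the three conversions documented in the DOC strings above reduce to the following arithmetic (function names here are illustrative, not the kernels' API):

// Illustration only, under the assumption of a single symmetric scale per
// tensor; the real kernels perform these conversions via MKL-DNN reorders.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t Quantize(float x, float scale) {  // FP32 -> INT8
  float q = std::round(x * scale);
  return static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, q)));
}

float Dequantize(int8_t q, float scale) {  // INT8 -> FP32
  return static_cast<float>(q) / scale;
}

int8_t Requantize(int8_t q, float scale_in, float scale_out) {  // INT8 -> INT8
  // Equivalent to dequantizing with the input scale and re-quantizing with
  // the output scale, without materializing an FP32 tensor in the graph.
  return Quantize(Dequantize(q, scale_in), scale_out);
}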