Commit d5167b48 authored by xiaolil1

Modify the INT8 kernel PR: temporarily remove the requantize op.
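The dequantize kernel now takes its source format from the input tensor instead of hardcoding nhwc, the quantize/dequantize op maker descriptions are tightened, and the requantize operator is removed for now.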

Parent a042d86b
......@@ -54,7 +54,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
-mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
+mkldnn::memory::format src_fmt = input->format();
mkldnn::primitive_attr attri;
int mask = 0;
......@@ -101,12 +101,10 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut
}
void DeQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from INT8 to FP32
)DOC");
AddInput("Input","input data");
AddInput("Scale","scale data");
AddOutput("Output","output data");
AddComment(R"DOC(This op will quantize data from INT8 to FP32)DOC");
}
} // namespace operators
......
......@@ -95,18 +95,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio
void QuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from FP32 to INT8
)DOC");
AddInput("Input","input data");
AddInput("Scale","scale data");
AddOutput("Output","output data");
AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
//TODO Support FP32->S8 quantization.
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/requantize_op.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
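// The requantize kernel converts INT8 data to INT8 with a new quantization
// scale, implemented as an MKLDNN reorder whose output_scales attribute
// rescales every value during the copy.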
template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
//std::cout<<"this is requant op!!!!!"<<std::endl;
auto* input = ctx.Input<Tensor>("Input");
//auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
// FIXME: the destination type and the memory formats are hardcoded for the
// u8/nhwc case; they should come from output->type() and the tensor formats.
mkldnn::memory::data_type dst_dt = mkldnn::memory::data_type::u8;
mkldnn::memory::format src_fmt = memory::format::nhwc;
mkldnn::memory::format dst_fmt = memory::format::nhwc;
const T* input_data = input->data<T>();
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
std::vector<float> scale_data = {*(scale->data<float>())};
mkldnn::primitive_attr attri;
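// With mask == 0 a single common output scale is applied to the whole tensor;
// a non-zero mask would select per-dimension (e.g. per-channel) scales.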
int mask = 0;
attri.set_output_scales(mask, scale_data);
//attri.set_int_output_round_mode(round_nearest);  // FIXME: set an explicit round mode.
auto src_md = platform::MKLDNNMemDesc(
    {src_tz}, src_dt, src_fmt);  // FIXME: s8 input still needs to be supported.
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine);
auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, dst_dt, dst_fmt);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data));
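// dst_memory wraps the output buffer directly, so the reorder writes its
// result straight into the output tensor.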
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri));
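// A reorder built with these attributes performs the actual requantization:
// each value is multiplied by the output scale while it is copied.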
int is_sum = ctx.Attr<int>("is_sum");
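// When the op is fused with a preceding sum, the data is assumed to have been
// requantized by that fusion already, so it is copied through unchanged.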
if (is_sum) {
  memcpy(output_data, input_data, sizeof(uint8_t) * input->numel());
} else {
  auto reorder_p = std::make_shared<reorder>(*reorder_pd, *src_memory_p, dst_memory);
  pipeline.push_back(*reorder_p);
  stream(stream::kind::eager).submit(pipeline).wait();
}
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
//std::cout<<"requant op end!!!!!"<<std::endl;
}
};
framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
    framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
    ctx.GetPlace(), layout_, library_);
}
void ReQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will requantize data from INT8 to INT8
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel<int8_t>);
......@@ -153,7 +153,6 @@ class MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
-//mem_p = nullptr;
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
......@@ -234,10 +233,9 @@ class MKLDNNHandler {
std::shared_ptr<mkldnn::primitive> reorder_p;
if (mpd != user_mpd) {
target_memory_p = std::make_shared<mkldnn::memory>(mpd);
std::shared_ptr<mkldnn::reorder> reorder_p;// =
//std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
std::shared_ptr<mkldnn::reorder> reorder_p;
if(is_INT8){
-mkldnn::primitive_attr attri;
+mkldnn::primitive_attr attri;  // attributes for the INT8 weights and bias reorder
attri.set_output_scales(mask, scale_data);
auto reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
......