Commit 5f8b86a4 authored by xiaolil1

enable s8 convolution for SE-ResNeXt

Parent 52854e88
@@ -300,7 +300,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    const bool is_test = ctx.Attr<bool>("is_test");
    auto& dev_ctx =
@@ -433,6 +432,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
      dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
    }
    } else if (is_INT8 && dst_memory_p) {
      if (fuse_residual_conn) {
        // The output shares the residual tensor's buffer, so it must also
        // share its data type (u8 or s8).
        auto residual_param = ctx.Input<Tensor>("ResidualData");
        auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
        output->ShareDataWith(*residual_param);
        if (residual_dt == mkldnn::memory::data_type::u8) {
          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
          dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
        } else {
          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
          dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
        }
      } else {
        // A fused ReLU guarantees a non-negative output range, so u8 is
        // used; otherwise the destination stays signed.
        if (fuse_relu) {
          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
          dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
        } else {
          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
          dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
        }
      }
    }
    if (!is_INT8) {
@@ -552,6 +572,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      output->set_layout(DataLayout::kMKLDNN);
      output->set_format(GetMKLDNNFormat(*dst_memory_p));
    } else {
      VLOG(3) << "init FP32 conv pipeline";
      if (src_memory_reorder_p) {
        pipeline.push_back(*src_memory_reorder_p);
      }
@@ -773,6 +794,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      output->set_layout(DataLayout::kMKLDNN);
      output->set_format(GetMKLDNNFormat(*dst_memory_p));
    } else {
      VLOG(3) << "init INT8 conv pipeline";
      if (src_memory_reorder_p) {
        pipeline.push_back(*src_memory_reorder_p);
      }
@@ -1152,7 +1174,8 @@ namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::ConvMKLDNNOpKernel<float>,
-                   ops::ConvMKLDNNOpKernel<uint8_t>);
+                   ops::ConvMKLDNNOpKernel<uint8_t>,
+                   ops::ConvMKLDNNOpKernel<int8_t>);
 REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::ConvMKLDNNGradOpKernel<float>);
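The INT8 destination-memory logic added to the conv kernel above boils down to a small type-selection rule. A minimal restatement of just that rule (the enum and function names are illustrative, not part of the patch):

#include <utility>

enum class DstType { U8, S8 };

// With a fused residual connection the output buffer is shared with the
// residual tensor, so its data type wins; otherwise a fused ReLU implies a
// non-negative range (u8), and s8 is used in the signed case.
DstType ConvOutputType(bool fuse_residual_conn, DstType residual_dt,
                       bool fuse_relu) {
  if (fuse_residual_conn) return residual_dt;
  return fuse_relu ? DstType::U8 : DstType::S8;
}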
@@ -30,7 +30,6 @@ using Tensor = framework::Tensor;
 using framework::DataLayout;
 using mkldnn::stream;
 using platform::GetMKLDNNFormat;
-//using MKLDNNDataType = mkldnn::memory::data_type;
 template <typename T>
 class DeQuantOpKernel : public framework::OpKernel<T> {
@@ -46,7 +45,6 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
   const T* input_data = input->data<T>();
   float* output_data = output->mutable_data<float>(ctx.GetPlace());
-  //T scale_data = *(scale->data<T>());
   std::vector<float> scale_data = {*(scale->data<float>())};
   std::vector<float> reorder_scale = {1.0f / scale_data[0]};
@@ -77,7 +75,6 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
   pipeline.push_back(*reorder_p);
   stream(stream::kind::eager).submit(pipeline).wait();
-  //output->set_layout(DataLayout::kMKLDNN);
   output->set_format(GetMKLDNNFormat(dst_memory));
 }
@@ -114,5 +111,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>);
+REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);
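Functionally, the dequantize kernel performs the inverse scaling, realized above as an MKL-DNN reorder with output scale 1.0f / scale. A plain-loop sketch of the same arithmetic (std::vector standing in for the Tensor API, purely for illustration):

#include <cstdint>
#include <vector>

std::vector<float> Dequantize(const std::vector<int8_t>& in, float scale) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    // Mirrors reorder_scale = 1.0f / scale_data[0] in the kernel above.
    out[i] = static_cast<float>(in[i]) * (1.0f / scale);
  }
  return out;
}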
@@ -46,7 +46,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
   std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
   const T* input_data = input->data<T>();
-  uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
   std::vector<T> scale_data = {*(scale->data<T>())};
   mkldnn::primitive_attr attri;
@@ -59,20 +59,32 @@ class QuantOpKernel : public framework::OpKernel<T> {
   auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
   std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
-  auto dst_md = platform::MKLDNNMemDesc(
-      {dst_tz}, memory::data_type::u8, memory::format::nhwc);
-  auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
-  auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data));
+  // Signed inputs quantize to s8; non-negative inputs to u8.
+  bool is_negative = ctx.Attr<bool>("is_negative_input");
+  mkldnn::memory::primitive_desc dst_pd;
+  std::shared_ptr<mkldnn::memory> dst_memory;
+  if (is_negative) {
+    int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, memory::data_type::s8, memory::format::nhwc);
+    dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
+    dst_memory.reset(new mkldnn::memory(dst_pd, to_void_cast<int8_t>(output_data)));
+  } else {
+    uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, memory::data_type::u8, memory::format::nhwc);
+    dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
+    dst_memory.reset(new mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data)));
+  }
   auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
       new reorder::primitive_desc(src_pd, dst_pd, attri));
-  auto reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
+  auto reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, *dst_memory));
   pipeline.push_back(*reorder_p);
   stream(stream::kind::eager).submit(pipeline).wait();
   output->set_layout(DataLayout::kMKLDNN);
-  output->set_format(GetMKLDNNFormat(dst_memory));
+  output->set_format(GetMKLDNNFormat(*dst_memory));
 }
};
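At the value level, the quantize kernel's new branch only decides the destination range. A minimal sketch of that behavior (the explicit rounding and saturation here are assumptions for illustration; the kernel above delegates both to the MKL-DNN reorder):

#include <algorithm>
#include <cmath>
#include <cstdint>

// is_negative_input picks the destination range: [-128, 127] for s8 when
// the input can be negative, [0, 255] for u8 otherwise (e.g. post-ReLU
// activations).
int32_t QuantizeValue(float x, float scale, bool is_negative_input) {
  const float lo = is_negative_input ? -128.0f : 0.0f;
  const float hi = is_negative_input ? 127.0f : 255.0f;
  return static_cast<int32_t>(std::max(lo, std::min(hi, std::round(x * scale))));
}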