提交 b6ff1105 编写于 作者: X xiaolil1

enable force fp32 output

上级 2d4321d9
......@@ -334,7 +334,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
int groups = ctx.Attr<int>("groups");
//std::cout<<"force_fp32_output = "<<force_fp32_output<<std::endl;
if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32");
}
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
......@@ -367,7 +373,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd";
//std::cout<<key<<std::endl;
bool is_INT8 = ctx.HasInput("Scale_in")? true : false;
bool need_s8_to_u8 = false;
......@@ -444,7 +450,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
}
} else {
} else if(!force_fp32_output){
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
......@@ -452,6 +458,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
}
} else {
float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<float>(output_data));
}
}
......@@ -600,7 +609,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto sum_scale_key = key + "@sum_scale";
auto scale_in_eltwise_key = key + "@scale_in_eltwise";
std::vector<float> scale_in_data;
std::vector<float> scale_out_data;
std::vector<float> scale_out_data = {1.0f};
std::vector<float> scale_weights_data;
std::vector<float> scale_in_eltwise_data;
std::vector<float> output_shift_scale;
......@@ -619,6 +628,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
for(int i=0; i<count; i++){
scale_weights_data[i] =*(scale_weights->data<float>() + i);
}
if(!force_fp32_output)
scale_out_data = {*(scale_out->data<float>())};
output_shift_scale.resize(count);
#pragma omp parallel for if (count > 1)
......@@ -678,6 +688,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char)))
: paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
if(force_fp32_output){
dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
}
if(fuse_residual_conn){
auto residual = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
......@@ -738,7 +752,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
}
} else {
} else if(!force_fp32_output){
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p =
......@@ -748,6 +762,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
}
} else {
float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
}
// create convolution op primitive
......@@ -793,6 +811,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else {
//std::cout<<"this is int8 init"<<std::endl;
if(src_memory_reorder_p){
pipeline.push_back(*src_memory_reorder_p);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册