Commit b6ff1105 authored by xiaolil1

enable force fp32 output

Parent 2d4321d9
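In brief (a reading of the diff below): the INT8 MKL-DNN convolution kernel gains a force_fp32_output attribute. When it is set, the kernel rejects the residual-connection fusion, allocates a float output buffer in both the cached-primitive and first-run paths, overrides the destination data type to f32, and leaves scale_out_data at its new default of {1.0f}, so the convolution primitive dequantizes its result instead of emitting s8/u8. A condensed sketch of the resulting output-type dispatch appears after the relevant hunks.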
@@ -334,7 +334,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
     int groups = ctx.Attr<int>("groups");
+    // std::cout << "force_fp32_output = " << force_fp32_output << std::endl;
+    if (fuse_residual_conn) {
+      PADDLE_ENFORCE(!force_fp32_output,
+                     "residual fusion does not support forced fp32 output");
+    }
     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
@@ -367,7 +373,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         src_tz, weights_tz, strides, paddings, dilations, groups,
         ctx.op().Output("Output"));
     const std::string key_conv_pd = key + "@conv_pd";
+    // std::cout << key << std::endl;
     bool is_INT8 = ctx.HasInput("Scale_in") ? true : false;
     bool need_s8_to_u8 = false;
@@ -444,7 +450,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
           dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
         }
-      } else {
+      } else if (!force_fp32_output) {
        if (fuse_relu) {
          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
          dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
@@ -452,6 +458,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
          dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
        }
+      } else {
+        float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+        dst_memory_p->set_data_handle(to_void_cast<float>(output_data));
       }
     }
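Taken together, the two hunks above implement a three-way choice of output buffer type. A condensed, self-contained sketch of that dispatch follows; OutputKind and ChooseOutputKind are illustrative names (the real kernel branches inline on the same flags), and the residual-fusion branch, which stays int8 and rejects fp32 earlier, is omitted:

// Condensed view of the dispatch patched above; illustrative only.
enum class OutputKind { kU8, kS8, kF32 };

OutputKind ChooseOutputKind(bool fuse_relu, bool force_fp32_output) {
  if (!force_fp32_output) {
    // INT8 path: a fused ReLU makes every result non-negative, so the
    // output can be stored as u8; otherwise it stays signed s8.
    return fuse_relu ? OutputKind::kU8 : OutputKind::kS8;
  }
  // Forced fp32: a float buffer is allocated and the primitive
  // dequantizes directly into it.
  return OutputKind::kF32;
}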
@@ -600,7 +609,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       auto sum_scale_key = key + "@sum_scale";
       auto scale_in_eltwise_key = key + "@scale_in_eltwise";
       std::vector<float> scale_in_data;
-      std::vector<float> scale_out_data;
+      std::vector<float> scale_out_data = {1.0f};
       std::vector<float> scale_weights_data;
       std::vector<float> scale_in_eltwise_data;
       std::vector<float> output_shift_scale;
@@ -619,7 +628,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         for (int i = 0; i < count; i++) {
           scale_weights_data[i] = *(scale_weights->data<float>() + i);
         }
-        scale_out_data = {*(scale_out->data<float>())};
+        if (!force_fp32_output)
+          scale_out_data = {*(scale_out->data<float>())};
         output_shift_scale.resize(count);
 #pragma omp parallel for if (count > 1)
         for (int i = 0; i < count; i++) {
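The net effect of the two scale hunks: with force_fp32_output set, scale_out_data keeps its new {1.0f} default instead of being read from the Scale_out input, so the per-output-channel output_shift_scale becomes a pure dequantization factor. A worked sketch, assuming the conventional scale_out / (scale_in * scale_weight) formula (the actual loop body sits outside the visible hunks):

#include <vector>

// Assumed formula (not shown in the visible hunks): each channel's
// shift scale maps the int32 accumulator back to the target range.
std::vector<float> OutputShiftScale(float scale_in,
                                    const std::vector<float>& scale_weights,
                                    float scale_out) {
  std::vector<float> shift(scale_weights.size());
  for (size_t i = 0; i < shift.size(); ++i) {
    // With scale_out == 1.0f (force_fp32_output), this is a pure
    // dequantization factor; otherwise it requantizes to int8 range.
    shift[i] = scale_out / (scale_in * scale_weights[i]);
  }
  return shift;
}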
@@ -678,6 +688,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char)))
           : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
+      if (force_fp32_output) {
+        dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
+      }
       if (fuse_residual_conn) {
         auto residual = ctx.Input<Tensor>("ResidualData");
         auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
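Overriding dst_dt is what lets the change reach the primitive descriptor: dst_dt is presumably fed into the destination memory descriptor when the convolution primitive is created further down (outside the visible hunks), so with force_fp32_output the primitive itself writes dequantized f32 data and no separate reorder is needed.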
@@ -738,7 +752,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           dst_memory_p =
               handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
         }
-      } else {
+      } else if (!force_fp32_output) {
        if (fuse_relu) {
          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
          dst_memory_p =
@@ -748,6 +762,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
          dst_memory_p =
              handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
        }
+      } else {
+        float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+        dst_memory_p =
+            handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
       }

       // create convolution op primitive
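Note that this is the second copy of the output-buffer dispatch: the earlier hunks patched the path that reuses a cached dst_memory_p via set_data_handle, while this one patches the path that first acquires the destination memory from the handler. Both must agree, otherwise a cached run and a cold run would hand the primitive buffers of different types.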
@@ -793,6 +811,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       output->set_layout(DataLayout::kMKLDNN);
       output->set_format(GetMKLDNNFormat(*dst_memory_p));
     } else {
+      // std::cout << "this is int8 init" << std::endl;
       if (src_memory_reorder_p) {
         pipeline.push_back(*src_memory_reorder_p);
       }
...