Commit 93c4ee01 authored by xiaolil1

Integrate the different-format residual fix into the INT8 path

Parent 51e5950e
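The diff below teaches the residual (conv + elementwise_add) fusion path to handle a residual tensor whose memory format differs from the convolution's destination format: when the formats differ, the residual data is acquired as its own MKL-DNN memory and merged into the destination via AcquireDstMemoryFromResidualDataMemory (which reorders it); only when the formats match does the output share the residual buffer directly. The INT8 u8/s8 branches now follow the same rule as the existing FP32 path. As a minimal orientation sketch of that dispatch, the standalone program below uses simplified stand-in types and helper names (TensorDesc, PlanResidual, NeedS8ToU8), not the actual Paddle or MKL-DNN API.

// residual_format_dispatch.cc -- illustrative only; simplified stand-ins for
// the Paddle/MKL-DNN types used in the diff below (not the real API).
#include <cstdint>
#include <iostream>

enum class Format { nchw, nChw8c, nChw16c };   // a few MKL-DNN-style layouts
enum class DataType { f32, s8, u8 };           // element types handled here

struct TensorDesc {
  Format format;
  DataType type;
};

// How the conv output obtains its residual operand:
//  - ShareBuffer: formats match, so the output can alias the residual tensor;
//  - ReorderIntoDst: formats differ, so the residual data must be reordered
//    into the destination layout before the convolution accumulates into it.
enum class ResidualPlan { ShareBuffer, ReorderIntoDst };

ResidualPlan PlanResidual(const TensorDesc& residual, Format dst_format) {
  return residual.format == dst_format ? ResidualPlan::ShareBuffer
                                       : ResidualPlan::ReorderIntoDst;
}

// In the INT8 path, an s8 output fused with ReLU is later reinterpreted as u8
// (the need_s8_to_u8 flag in the kernel).
bool NeedS8ToU8(DataType output_type, bool fuse_relu) {
  return output_type == DataType::s8 && fuse_relu;
}

int main() {
  TensorDesc residual{Format::nchw, DataType::u8};
  Format dst_format = Format::nChw16c;

  if (PlanResidual(residual, dst_format) == ResidualPlan::ReorderIntoDst)
    std::cout << "reorder residual into the destination layout\n";
  else
    std::cout << "share the residual buffer with the output\n";

  std::cout << std::boolalpha
            << "need_s8_to_u8: " << NeedS8ToU8(DataType::s8, true) << "\n";
  return 0;
}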
@@ -535,84 +535,95 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::memory> dst_memory_p;
     bool need_s8_to_u8 = false;
-    if(is_INT8){
-      if (fuse_residual_conn) {
+    if(fuse_residual_conn) {
       auto residual_param = ctx.Input<Tensor>("ResidualData");
       PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                         "Output and elementwise parameter need to have the "
                         "same dimension sizes");
-      output->ShareDataWith(*residual_param);
       auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
+      if(residual_param->format() != handler.GetDstFormat()) {
+        auto residual_data_tz =
+            paddle::framework::vectorize2int(residual_param->dims());
+        auto residual_data_type =
+            paddle::framework::ToMKLDNNDataType(residual_param->type());
+        auto user_residual_md = platform::MKLDNNMemDesc(
+            residual_data_tz, residual_data_type, residual_param->format());
+        if(is_INT8){
           if(residual_dt == mkldnn::memory::data_type::u8){
+            auto residual_param_data = residual_param->data<uint8_t>();
+            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
+                user_residual_md, to_void_cast<uint8_t>(residual_param_data));
+            PADDLE_ENFORCE(
+                residual_param_data != nullptr,
+                "Provide data if you want MKLDNN conv+elementwise_add fusion");
             uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
             dst_memory_p =
-                handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+                handler.AcquireDstMemoryFromResidualDataMemory(
+                    user_residual_memory_p, to_void_cast<uint8_t>(output_data), pipeline);
           } else{
+            auto residual_param_data = residual_param->data<int8_t>();
+            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
+                user_residual_md, to_void_cast<int8_t>(residual_param_data));
+            PADDLE_ENFORCE(
+                residual_param_data != nullptr,
+                "Provide data if you want MKLDNN conv+elementwise_add fusion");
             int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
             dst_memory_p =
-                handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+                handler.AcquireDstMemoryFromResidualDataMemory(
+                    user_residual_memory_p, to_void_cast<int8_t>(output_data), pipeline);
             if(fuse_relu)
               need_s8_to_u8 = true;
           }
-      } else {
-        if(fuse_relu){
-          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
-          dst_memory_p =
-              handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
-        } else{
-          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
-          dst_memory_p =
-              handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
-        }
-      }
-    } else{
-      // create reorder primitive if the input format is not the preferred one
-      // auto src_memory_p =
-      //     handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
-      // auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
-      //     user_weights_memory_p, pipeline, is_test);
-      // std::shared_ptr<mkldnn::memory> dst_memory_p;
-      if (fuse_residual_conn) {
-        auto residual_param = ctx.Input<Tensor>("ResidualData");
+        } else{
           auto residual_param_data = residual_param->data<T>();
+          auto user_residual_memory_p = handler.AcquireResidualDataMemory(
+              user_residual_md, to_void_cast<T>(residual_param_data));
           PADDLE_ENFORCE(
               residual_param_data != nullptr,
               "Provide data if you want MKLDNN conv+elementwise_add fusion");
-        PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
-                          "Output and elementwise parameter need to have the "
-                          "same dimension sizes");
-        if (residual_param->format() != handler.GetDstFormat()) {
           auto output_data =
               output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
-          auto residual_data_tz =
-              paddle::framework::vectorize2int(residual_param->dims());
-          auto residual_data_type =
-              paddle::framework::ToMKLDNNDataType(residual_param->type());
-          auto user_residual_md = platform::MKLDNNMemDesc(
-              residual_data_tz, residual_data_type, residual_param->format());
-          auto user_residual_memory_p = handler.AcquireResidualDataMemory(
-              user_residual_md, to_void_cast<T>(residual_param_data));
           dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
               user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
+        }
       } else {
         output->ShareDataWith(*residual_param);
+        if(is_INT8){
+          if(residual_dt == mkldnn::memory::data_type::u8){
+            uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+            dst_memory_p =
+                handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+          } else{
+            int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+            dst_memory_p =
+                handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+            if(fuse_relu)
+              need_s8_to_u8 = true;
+          }
+        } else{
           auto output_data = output->mutable_data<T>(ctx.GetPlace());
           dst_memory_p =
               handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
         }
+      }
     } else {
+      if(is_INT8){
+        if(fuse_relu){
+          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+        } else{
+          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        }
+      } else{
         auto output_data =
             output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
         dst_memory_p =
             handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
       }
-      // dst_memory_p =
-      //     handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
     }
     // create convolution op primitive