提交 786a72cd 编写于 作者: H Haihao Shen

Fix force_fp32_output handling so that the correct dst_memory_pd is used

上级 1dac27b5
......@@ -413,6 +413,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
scale_weights_data[i] =*(scale_weights->data<float>() + i);
}
scale_out_data = {*(scale_out->data<float>())};
if(force_fp32_output)
scale_out_data[0] = 1.0;
output_shift_scale.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){
......@@ -694,7 +696,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->ShareDataWith(*residual_param);
if(is_INT8){
if(residual_dt == mkldnn::memory::data_type::u8){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
......@@ -712,7 +713,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
}
} else {
if(is_INT8){
if(is_INT8 && !force_fp32_output){
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册