提交 2e09ab36 编写于 作者: H Haihao Shen

Support force fp32 ouput for int8 conv

上级 3050a73f
...@@ -339,6 +339,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -339,6 +339,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
...@@ -519,6 +520,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -519,6 +520,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if(dst_dt != residual_dt) if(dst_dt != residual_dt)
dst_dt = residual_dt; dst_dt = residual_dt;
} }
if(force_fp32_output)
dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format))); dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format)));
mds[2] = src_md; mds[2] = src_md;
mds[3] = weights_md; mds[3] = weights_md;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册