提交 59da41e8 编写于 作者: X xiaolil1

fix weight fmt bug for dwc conv (mobilenet)

上级 59a90d08
...@@ -511,8 +511,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -511,8 +511,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format))); src_tz, memory::data_type::u8, chosen_memory_format)));
weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, weights_tz, memory::data_type::s8, chosen_memory_format)));
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char))); auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
if(fuse_residual_conn){ if(fuse_residual_conn){
auto residual = ctx.Input<Tensor>("ResidualData"); auto residual = ctx.Input<Tensor>("ResidualData");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册