diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc index f6f40b1daf4b6e5502190aaaab6b976fc960bcda..f4bad7b712b2b078ed68f0a3d0e751d9ae2d6191 100644 --- a/paddle/fluid/operators/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/pool_mkldnn_op.cc @@ -35,6 +35,7 @@ static std::string gethash(const memory::dims& input_dims, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, + const memory::data_type& dt, const std::string& suffix) { auto dims2str = [](const memory::dims& operand_dims) { std::string dstr = ""; @@ -44,7 +45,7 @@ static std::string gethash(const memory::dims& input_dims, return dstr; }; return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) + - dims2str(paddings) + pooling_type + suffix; + dims2str(paddings) + std::to_string(dt) + pooling_type + suffix; } static inline int ComputeCeiledOutput(int input_size, int kernel_size, @@ -111,8 +112,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { auto input_format = input->format(); memory::format output_format{memory::format::format_undef}; + mkldnn::memory::data_type dt = + paddle::framework::ToMKLDNNDataType(input->type()); const std::string key = gethash(src_tz, pooling_type, ksize, strides, - paddings, ctx.op().Output("Out")); + paddings, dt, ctx.op().Output("Out")); const std::string key_pool_p = key + "@pool_p"; const std::string key_pool_pd = key + "@pool_pd"; const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; @@ -131,9 +134,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { padding_right_bottom); } - mkldnn::memory::data_type dt = - paddle::framework::ToMKLDNNDataType(input->type()); - auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format); /* create memory descriptor for pooling without specified format @@ -293,8 +293,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "Out" variable // This name will be used as key when referring info from device context - const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides, - paddings, ctx.op().Input("Out")); + const std::string key = + gethash(diff_src_tz, pooling_type, ksize, strides, paddings, + memory::data_type::f32, ctx.op().Input("Out")); const std::string key_pool_bwd_p = key + "@pool_bwd_p"; const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";