未验证 提交 16cb3ebd 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15268 from xiaolil1/pool-int8

Enhance key generation for Pool INT8 test
......@@ -35,6 +35,7 @@ static std::string gethash(const memory::dims& input_dims,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const memory::data_type& dt,
const std::string& suffix) {
auto dims2str = [](const memory::dims& operand_dims) {
std::string dstr = "";
......@@ -44,7 +45,7 @@ static std::string gethash(const memory::dims& input_dims,
return dstr;
};
return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
dims2str(paddings) + pooling_type + suffix;
dims2str(paddings) + std::to_string(dt) + pooling_type + suffix;
}
static inline int ComputeCeiledOutput(int input_size, int kernel_size,
......@@ -111,8 +112,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input_format = input->format();
memory::format output_format{memory::format::format_undef};
mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
const std::string key = gethash(src_tz, pooling_type, ksize, strides,
paddings, ctx.op().Output("Out"));
paddings, dt, ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p";
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
......@@ -131,9 +134,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
padding_right_bottom);
}
mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
/* create memory descriptor for pooling without specified format
......@@ -293,8 +293,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
paddings, ctx.op().Input("Out"));
const std::string key =
gethash(diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, ctx.op().Input("Out"));
const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册