Unverified commit 16cb3ebd, authored by Xin Pan, committed by GitHub

Merge pull request #15268 from xiaolil1/pool-int8

Enhance key generation for Pool INT8 test
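This change folds the MKLDNN memory data type into the string key under which the pooling primitive is cached in the device context, so FP32 and INT8 pooling primitives with identical shapes, strides, and paddings no longer collide on the same cache entry. The forward kernel derives `dt` from the input tensor via `paddle::framework::ToMKLDNNDataType(input->type())` and passes it to `gethash`; the backward kernel runs only in FP32 and therefore passes `memory::data_type::f32` explicitly.

Below is a minimal, self-contained sketch of the key composition after this change. It is not the PaddlePaddle source: the `dims` and `data_type` stand-ins and the `"pool_out"` suffix are hypothetical, introduced only so the example compiles on its own.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for mkldnn's memory::dims and memory::data_type.
using dims = std::vector<int>;
enum class data_type { f32 = 1, s8 = 2, u8 = 3 };

// Sketch of the key generation after this patch: the data type is folded
// into the string alongside shapes, pooling type, and the output name.
static std::string gethash(const dims& input_dims, const std::string& pooling_type,
                           const dims& ksize, const dims& strides,
                           const dims& paddings, data_type dt,
                           const std::string& suffix) {
  auto dims2str = [](const dims& d) {
    std::string s;
    for (int v : d) s += std::to_string(v) + "-";
    return s;
  };
  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
         dims2str(paddings) + std::to_string(static_cast<int>(dt)) +
         pooling_type + suffix;
}

int main() {
  dims src = {1, 3, 224, 224}, k = {2, 2}, s = {2, 2}, p = {0, 0};
  // Same geometry, different data types -> different cache keys, so a
  // cached FP32 pooling primitive is not reused for an INT8 run.
  auto key_fp32 = gethash(src, "max", k, s, p, data_type::f32, "pool_out");
  auto key_int8 = gethash(src, "max", k, s, p, data_type::u8, "pool_out");
  return key_fp32 == key_int8 ? 1 : 0;  // exits 0: keys differ, as intended
}
```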
@@ -35,6 +35,7 @@ static std::string gethash(const memory::dims& input_dims,
                            const std::vector<int>& ksize,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
+                           const memory::data_type& dt,
                            const std::string& suffix) {
   auto dims2str = [](const memory::dims& operand_dims) {
     std::string dstr = "";
@@ -44,7 +45,7 @@ static std::string gethash(const memory::dims& input_dims,
     return dstr;
   };
   return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
-         dims2str(paddings) + pooling_type + suffix;
+         dims2str(paddings) + std::to_string(dt) + pooling_type + suffix;
 }
 
 static inline int ComputeCeiledOutput(int input_size, int kernel_size,
@@ -111,8 +112,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto input_format = input->format();
     memory::format output_format{memory::format::format_undef};
 
+    mkldnn::memory::data_type dt =
+        paddle::framework::ToMKLDNNDataType(input->type());
     const std::string key = gethash(src_tz, pooling_type, ksize, strides,
-                                    paddings, ctx.op().Output("Out"));
+                                    paddings, dt, ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
@@ -131,9 +134,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                          padding_right_bottom);
     }
 
-    mkldnn::memory::data_type dt =
-        paddle::framework::ToMKLDNNDataType(input->type());
-
     auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
 
     /* create memory descriptor for pooling without specified format
@@ -293,8 +293,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // Get an unique name from "argument" name of "Out" variable
     // This name will be used as key when referring info from device context
-    const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
-                                    paddings, ctx.op().Input("Out"));
+    const std::string key =
+        gethash(diff_src_tz, pooling_type, ksize, strides, paddings,
+                memory::data_type::f32, ctx.op().Input("Out"));
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
...