提交 7d7b1ec6 编写于 作者: X xiaolil1

modify pool to enable jit

上级 d99db05a
......@@ -141,9 +141,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
*/
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt,
mkldnn::memory::format::any);
auto propagation = src_md.data.data_type == mkldnn_f32 ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
CreatePrimitiveDesc(src_md, dst_md, propagation, strides, padding_left_top,
padding_right_bottom, ksize, pooling_type,
mkldnn_engine, ceil_mode);
......@@ -164,9 +164,15 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
if (propagation == mkldnn::prop_kind::forward_training) {
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
*(dst_memory.get()),
*workspace_memory);
} else{
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
*(dst_memory.get()));//,
//*workspace_memory);
}
dev_ctx.SetBlob(key_pool_p, pool_p);
......@@ -201,12 +207,13 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
const mkldnn::prop_kind& propagation,
const std::vector<int>& stride, const std::vector<int>& padding_left_top,
const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
const std::string& pooling_type, const mkldnn::engine& engine,
bool ceil_mode) const {
auto pool_desc = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward,
propagation,
pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
src, dst, stride, kernel, padding_left_top, padding_right_bot,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册