diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc
index 0a9a29956affedb8605ab9949070943fbbb54145..f6f40b1daf4b6e5502190aaaab6b976fc960bcda 100644
--- a/paddle/fluid/operators/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/pool_mkldnn_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/operators/pool_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
@@ -71,7 +72,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -130,20 +130,25 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
                         padding_right_bottom);
     }
-    auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), input_format);
+
+    mkldnn::memory::data_type dt =
+        paddle::framework::ToMKLDNNDataType(input->type());
+
+    auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
 
     /* create memory descriptor for pooling without specified format
      * ('any') which lets a primitive (pooling in this case) choose
      * the memory format preferred for best performance
      */
-    auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
-                                          mkldnn::memory::format::any);
-
+    auto dst_md =
+        platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any);
+    auto propagation = src_md.data.data_type == mkldnn_f32
+                           ? mkldnn::prop_kind::forward_training
+                           : mkldnn::prop_kind::forward_scoring;
     std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
-        CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
-                            padding_right_bottom, ksize, pooling_type,
-                            mkldnn_engine, ceil_mode, is_test);
+        CreatePrimitiveDesc(src_md, dst_md, propagation, strides,
+                            padding_left_top, padding_right_bottom, ksize,
+                            pooling_type, mkldnn_engine, ceil_mode, is_test);
 
     // save pool_pd into global device context to be referred in backward path
     if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
@@ -203,7 +208,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  private:
   std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
       const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
-      const std::vector<int>& stride, const std::vector<int>& padding_left_top,
+      const mkldnn::prop_kind& propagation, const std::vector<int>& stride,
+      const std::vector<int>& padding_left_top,
       const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
       const std::string& pooling_type, const mkldnn::engine& engine,
       bool ceil_mode, bool is_test) const {
@@ -411,6 +417,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>,
+                   ops::PoolMKLDNNOpKernel<int8_t>,
+                   ops::PoolMKLDNNOpKernel<uint8_t>);
+
 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::PoolMKLDNNGradOpKernel<float>);
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..954d9993b29622f6a4806c8c8dd45f21abafc4fa
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import division
+
+import unittest
+import numpy as np
+
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+def adaptive_start_index(index, input_size, output_size):
+    return int(np.floor(index * input_size / output_size))
+
+
+def adaptive_end_index(index, input_size, output_size):
+    return int(np.ceil((index + 1) * input_size / output_size))
+
+
+def max_pool2D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings,
+                             global_pool=0,
+                             ceil_mode=False,
+                             exclusive=True,
+                             adaptive=False):
+    N, C, H, W = x.shape
+    if global_pool == 1:
+        ksize = [H, W]
+    if adaptive:
+        H_out, W_out = ksize
+    else:
+        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
+                 ) // strides[0] + 1 if ceil_mode else (
+                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
+        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
+                 ) // strides[1] + 1 if ceil_mode else (
+                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    for i in range(H_out):
+        for j in range(W_out):
+            if adaptive:
+                r_start = adaptive_start_index(i, H, ksize[0])
+                r_end = adaptive_end_index(i, H, ksize[0])
+                c_start = adaptive_start_index(j, W, ksize[1])
+                c_end = adaptive_end_index(j, W, ksize[1])
+            else:
+                r_start = np.max((i * strides[0] - paddings[0], 0))
+                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
+                c_start = np.max((j * strides[1] - paddings[1], 0))
+                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
+    return out
+
+
+def avg_pool2D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings,
+                             global_pool=0,
+                             ceil_mode=False,
+                             exclusive=True,
+                             adaptive=False):
+    N, C, H, W = x.shape
+    if global_pool == 1:
+        ksize = [H, W]
+    if adaptive:
+        H_out, W_out = ksize
+    else:
+        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
+                 ) // strides[0] + 1 if ceil_mode else (
+                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
+        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
+                 ) // strides[1] + 1 if ceil_mode else (
+                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    for i in range(H_out):
+        for j in range(W_out):
+            if adaptive:
+                r_start = adaptive_start_index(i, H, ksize[0])
+                r_end = adaptive_end_index(i, H, ksize[0])
+                c_start = adaptive_start_index(j, W, ksize[1])
+                c_end = adaptive_end_index(j, W, ksize[1])
+            else:
+                r_start = np.max((i * strides[0] - paddings[0], 0))
+                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
+                c_start = np.max((j * strides[1] - paddings[1], 0))
+                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            field_size = ((r_end - r_start) * (c_end - c_start)) \
+                if (exclusive or adaptive) else (ksize[0] * ksize[1])
+            out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
+    return out
+
+
+class TestPool2D_Op(OpTest):
+    def setUp(self):
+        self.op_type = "pool2d"
+        self.use_cudnn = False
+        self.use_mkldnn = True
+        self.dtype = np.int8
+        self.init_test_case()
+        self.init_global_pool()
+        self.init_pool_type()
+        self.init_ceil_mode()
+        self.init_exclusive()
+        self.init_adaptive()
+        if self.global_pool:
+            self.paddings = [0 for _ in range(len(self.paddings))]
+        input = np.random.random(self.shape).astype(self.dtype)
+        output = self.pool2D_forward_naive(
+            input, self.ksize, self.strides, self.paddings, self.global_pool,
+            self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
+
+        self.attrs = {
+            'strides': self.strides,
+            'paddings': self.paddings,
+            'ksize': self.ksize,
+            'pooling_type': self.pool_type,
+            'global_pooling': self.global_pool,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+            'ceil_mode': self.ceil_mode,
+            'data_format':
+            'AnyLayout',  # TODO(dzhwinter) : should be fix latter
+            'exclusive': self.exclusive,
+            'adaptive': self.adaptive
+        }
+
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init_test_case(self):
+        self.shape = [2, 3, 5, 5]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = True
+
+    def init_ceil_mode(self):
+        self.ceil_mode = False
+
+    def init_exclusive(self):
+        self.exclusive = True
+
+    def init_adaptive(self):
+        self.adaptive = False
+
+
+class TestCase1(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = False
+
+
+class TestCase2(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+        self.dtype = np.uint8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = False
+
+
+class TestCase3(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+
+
+class TestCase4(TestCase1):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+        self.dtype = np.uint8
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+
+
+if __name__ == '__main__':
+    unittest.main()