提交 8f051b36 编写于 作者: X xiaoli.liu@intel.com

Enable INT8 pool OP

test=develop
上级 aba1f9b0
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/pool_op.h" #include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -71,7 +72,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -71,7 +72,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -130,20 +130,25 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -130,20 +130,25 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
padding_right_bottom); padding_right_bottom);
} }
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), input_format); mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
/* create memory descriptor for pooling without specified format /* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose * ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32, auto dst_md =
mkldnn::memory::format::any); platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any);
auto propagation = src_md.data.data_type == mkldnn_f32
? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd = std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top, CreatePrimitiveDesc(src_md, dst_md, propagation, strides,
padding_right_bottom, ksize, pooling_type, padding_left_top, padding_right_bottom, ksize,
mkldnn_engine, ceil_mode, is_test); pooling_type, mkldnn_engine, ceil_mode, is_test);
// save pool_pd into global device context to be referred in backward path // save pool_pd into global device context to be referred in backward path
if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd); if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
...@@ -203,7 +208,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -203,7 +208,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private: private:
std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc( std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst, const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
const std::vector<int>& stride, const std::vector<int>& padding_left_top, const mkldnn::prop_kind& propagation, const std::vector<int>& stride,
const std::vector<int>& padding_left_top,
const std::vector<int>& padding_right_bot, const std::vector<int>& kernel, const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
const std::string& pooling_type, const mkldnn::engine& engine, const std::string& pooling_type, const mkldnn::engine& engine,
bool ceil_mode, bool is_test) const { bool ceil_mode, bool is_test) const {
...@@ -411,6 +417,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -411,6 +417,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNOpKernel<float>); ops::PoolMKLDNNOpKernel<float>,
ops::PoolMKLDNNOpKernel<int8_t>,
ops::PoolMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNGradOpKernel<float>); ops::PoolMKLDNNGradOpKernel<float>);
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def max_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
return out
def avg_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
field_size = ((r_end - r_start) * (c_end - c_start)) \
if (exclusive or adaptive) else (ksize[0] * ksize[1])
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
return out
class TestPool2D_Op(OpTest):
def setUp(self):
self.op_type = "pool2d"
self.use_cudnn = False
self.use_mkldnn = True
self.dtype = np.int8
self.init_test_case()
self.init_global_pool()
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
self.init_adaptive()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype)
output = self.pool2D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'pooling_type': self.pool_type,
'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'ceil_mode': self.ceil_mode,
'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive,
'adaptive': self.adaptive
}
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
self.dtype = np.int8
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = True
def init_ceil_mode(self):
self.ceil_mode = False
def init_exclusive(self):
self.exclusive = True
def init_adaptive(self):
self.adaptive = False
class TestCase1(TestPool2D_Op):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
self.dtype = np.int8
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase2(TestPool2D_Op):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
self.dtype = np.uint8
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase3(TestPool2D_Op):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
self.dtype = np.int8
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestCase4(TestCase1):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
self.dtype = np.uint8
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册