提交 8f17c714 编写于 作者: X xiaolil1 提交者: Tao Luo

Conv int8 residual (#15145)

* Enable basic MKL-DNN INT8 Conv OP
test=develop

* Modify test case
test=develop

* Clean unittest code
test=develop

* Fix test
test=develop

* Modify test
test=develop

* Enable MKL-DNN INT8 Conv with Relu Fusion OP
test=develop

* Enable INT8 Conv with residual fusion OP
test=develop

* Modify code.
test=develop

* Modify basic INT8 Conv
test=develop

* Modify Conv.
test=develop

* fix style
test=develop

* Fix style
test=develop

* Fix test
test=develop

* Modify code.
test=develop

* Fix test
test=develop
上级 93d5c1ed
...@@ -318,10 +318,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -318,10 +318,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32");
}
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation // TODO(tpatejko): add support for dilation
...@@ -355,14 +359,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -355,14 +359,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
framework::DataTypeTrait<float>::DataType); framework::DataTypeTrait<float>::DataType);
} }
if (fuse_residual_conn) {
auto residual = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
if (dst_dt != residual_dt) dst_dt = residual_dt;
}
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
std::string key; std::string key;
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey( platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), dst_dt, ctx.op().Output("Output")); input->format(), fuse_relu, fuse_residual_conn,
ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr; std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
std::shared_ptr<mkldnn::memory> src_memory_p = nullptr; std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr; std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
...@@ -377,14 +390,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -377,14 +390,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_key = key + "@src_mem_p"; auto src_key = key + "@src_mem_p";
auto user_src_key = key + "@user_src_mem_p"; auto user_src_key = key + "@user_src_mem_p";
auto src_reorder_key = key + "@src_mem_preorder_p"; auto src_reorder_key = key + "@src_mem_preorder_p";
auto residual_reorder_key = key + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>( conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key)); dev_ctx.GetBlob(prim_key));
if (conv_p == nullptr || !is_test) { if (conv_p == nullptr || !is_test) {
const K* filter_data = filter->data<K>(); const K* filter_data = filter->data<K>();
auto scale_in_data = ctx.Attr<float>("Scale_in"); auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights"); auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
auto scale_out_data = auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out"); force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1; bool is_multi_channel = scale_weights_data.size() > 1;
...@@ -427,6 +446,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -427,6 +446,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
weights_tz, memory::data_type::s8, chosen_memory_format); weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_md = auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
...@@ -434,11 +454,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -434,11 +454,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::format::x); memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
fuse_relu, output_shift_scale, is_test); fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale, is_test);
} else { } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, conv_pd =
paddings, mkldnn_engine, fuse_relu, ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
output_shift_scale, is_test); mkldnn_engine, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale, is_test);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -463,7 +485,41 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -463,7 +485,41 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
user_weights_memory_p, pipeline, is_test, true, scale_weights_data, user_weights_memory_p, pipeline, is_test, true, scale_weights_data,
mask_reorder); mask_reorder);
if (!force_fp32_output) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
auto residual_dt =
paddle::framework::ToMKLDNNDataType(residual_param->type());
if (residual_param->format() != handler->GetDstFormat()) {
auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_dt, residual_param->format());
if (residual_dt == mkldnn::memory::data_type::u8) {
dst_memory_p = platform::SetDstMemory<uint8_t>(
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
} else {
need_s8_to_u8 = fuse_relu;
dst_memory_p = platform::SetDstMemory<int8_t>(
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
}
} else {
output->ShareDataWith(*residual_param);
if (residual_dt == mkldnn::memory::data_type::u8) {
dst_memory_p =
platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
need_s8_to_u8 = fuse_relu;
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
}
}
} else if (!force_fp32_output) {
if (fuse_relu) { if (fuse_relu) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else { } else {
...@@ -476,11 +532,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -476,11 +532,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create convolution op primitive // create convolution op primitive
auto scale_bias_key = key + "@scale_bias"; auto scale_bias_key = key + "@scale_bias";
if (bias) { if (bias) {
const float* bias_data = bias->data<float>(); const K* bias_data = bias->data<K>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<K>(), memory::format::x);
auto user_bias_memory_p = handler->AcquireBiasMemory( auto user_bias_memory_p = handler->AcquireBiasMemory(
user_bias_md, to_void_cast<float>(bias_data)); user_bias_md, to_void_cast<K>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p; std::shared_ptr<mkldnn::memory> bias_memory_p;
int mask_reorder = is_multi_channel ? 1 << 0 : 1; int mask_reorder = is_multi_channel ? 1 << 0 : 1;
int count = int count =
...@@ -526,26 +582,51 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -526,26 +582,51 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key)); mkldnn_engine, key));
} }
if (!force_fp32_output) {
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_dt =
paddle::framework::ToMKLDNNDataType(residual_param->type());
output->ShareDataWith(*residual_param);
if (residual_dt == mkldnn::memory::data_type::u8) {
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
&dst_memory_p);
} else {
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
&dst_memory_p);
}
} else if (!force_fp32_output) {
if (fuse_relu) { if (fuse_relu) {
dst_memory_p = platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler,
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler); &dst_memory_p);
} else { } else {
dst_memory_p = platform::SetDstMemoryHandler<int8_t>(ctx, output, handler,
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler); &dst_memory_p);
} }
} else { } else {
dst_memory_p = platform::SetDstMemoryHandler<float>(ctx, output, handler,
platform::SetDstMemoryHandler<float>(ctx, output, handler); &dst_memory_p);
} }
if (src_memory_reorder_p) { if (src_memory_reorder_p) {
pipeline.push_back(*src_memory_reorder_p); pipeline.push_back(*src_memory_reorder_p);
} }
auto residual_reorder_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(residual_reorder_key));
if (residual_reorder_p) {
pipeline.push_back(*residual_reorder_p);
}
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
if (need_s8_to_u8) {
output->mutable_data<uint8_t>(ctx.GetPlace());
}
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
...@@ -577,11 +658,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -577,11 +658,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
mkldnn::primitive_attr CreatePostOps( mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, const std::vector<float> output_shift_scale) const { bool fuse_relu, bool fuse_residual_conn,
const std::vector<float> output_shift_scale, float sum_scale) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale); conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
if (fuse_relu) { if (fuse_relu) {
constexpr float scale = 1.0f; constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f; constexpr float negative_slope = 0.0f;
...@@ -622,8 +707,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -622,8 +707,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
bool is_test) const { const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -634,8 +720,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -634,8 +720,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero); padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = mkldnn::primitive_attr conv_attr = CreatePostOps(
CreatePostOps(fuse_relu, output_shift_scale); fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -675,8 +761,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -675,8 +761,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
bool is_test) const { const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -687,8 +774,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -687,8 +774,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, bias, dst, propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = mkldnn::primitive_attr conv_attr = CreatePostOps(
CreatePostOps(fuse_relu, output_shift_scale); fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -891,7 +978,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -891,7 +978,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p)); input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
} // Compute() }
}; };
} // namespace operators } // namespace operators
......
...@@ -210,13 +210,15 @@ class MKLDNNHandler { ...@@ -210,13 +210,15 @@ class MKLDNNHandler {
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data))); dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
} }
static void AppendKey( static void AppendKey(std::string* key,
std::string* key, const mkldnn::memory::dims& input_dims, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides, const mkldnn::memory::dims& weights_dims,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& strides,
const int& groups, const mkldnn::memory::data_type& srcdt, const std::vector<int>& paddings,
const mkldnn::memory::format& format, const std::vector<int>& dilations, const int& groups,
const mkldnn::memory::data_type& dstdt, const std::string& suffix) { const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const bool& relu,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims); AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims); AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides); AppendKeyVec(key, strides);
...@@ -225,7 +227,8 @@ class MKLDNNHandler { ...@@ -225,7 +227,8 @@ class MKLDNNHandler {
AppendKey(key, std::to_string(groups)); AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt)); AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format)); AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(dstdt)); AppendKey(key, std::to_string(relu));
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix); AppendKey(key, suffix);
} }
...@@ -664,15 +667,35 @@ static std::shared_ptr<mkldnn::memory> SetDstMemory( ...@@ -664,15 +667,35 @@ static std::shared_ptr<mkldnn::memory> SetDstMemory(
} }
template <typename T> template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemoryHandler( static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler) { const framework::Tensor* residual_param,
const mkldnn::memory::desc& user_residual_md,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::vector<mkldnn::primitive>* pipeline) {
const T* residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
handler->AcquireResidualDataMemory(user_residual_md,
to_void_cast<T>(residual_param_data));
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), *pipeline);
return dst_memory_p;
}
template <typename T>
static void SetDstMemoryHandler(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::shared_ptr<mkldnn::memory>* dst_memory_p) {
T* output_data = output->mutable_data<T>( T* output_data = output->mutable_data<T>(
ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, ctx.GetPlace(), ::paddle::memory::Allocator::kDefault,
handler->GetDstMemorySize()); handler->GetDstMemorySize());
std::shared_ptr<mkldnn::memory> dst_memory_p; (*dst_memory_p)->set_data_handle(to_void_cast<T>(output_data));
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
return dst_memory_p;
} }
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -25,6 +25,15 @@ from test_conv2d_op import conv2d_forward_naive, TestConv2dOp ...@@ -25,6 +25,15 @@ from test_conv2d_op import conv2d_forward_naive, TestConv2dOp
def conv2d_forward_refer(input, filter, group, conv_param): def conv2d_forward_refer(input, filter, group, conv_param):
out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group, out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
conv_param) conv_param)
size = [in_n, out_c, out_h, out_w]
return format_reorder(out, size)
def format_reorder(out, size):
in_n = size[0]
out_h = size[2]
out_w = size[3]
out_c = size[1]
out_tmp = np.zeros((in_n, out_h, out_w, out_c)) out_tmp = np.zeros((in_n, out_h, out_w, out_c))
for n in range(in_n): for n in range(in_n):
for i in range(out_h): for i in range(out_h):
...@@ -48,6 +57,7 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -48,6 +57,7 @@ class TestConv2dInt8Op(TestConv2dOp):
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
self.init_fuse_relu() self.init_fuse_relu()
self.init_fuse_residual()
self.init_data_type() self.init_data_type()
conv2d_param = { conv2d_param = {
...@@ -79,11 +89,24 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -79,11 +89,24 @@ class TestConv2dInt8Op(TestConv2dOp):
np.round((input_shift) * self.scale_in).astype(np.int32), np.round((input_shift) * self.scale_in).astype(np.int32),
filter_int, self.groups, filter_int, self.groups,
conv2d_param).astype(np.float32) * scale_output_shift conv2d_param).astype(np.float32) * scale_output_shift
if self.fuse_residual:
input_residual = np.random.randint(
-5, 5, self.input_residual_size).astype(self.srctype)
output_tmp = np.round(output1 - output2 + format_reorder(
input_residual, self.input_residual_size).astype(
self.srctype) * (self.scale_out / self.scale_in_eltwise
))
if self.fuse_relu:
output = np.maximum(output_tmp, 0).astype(self.dsttype)
else:
output = output_tmp.astype(self.dsttype)
else:
if self.fuse_relu: if self.fuse_relu:
output = np.maximum(np.round(output1 - output2), output = np.maximum(np.round(output1 - output2),
0).astype(self.dsttype) 0).astype(self.dsttype)
else: else:
output = np.round(output1 - output2).astype(self.dsttype) output = np.round(output1 - output2).astype(self.dsttype)
else: else:
filter_int = np.round(filter * filter_int = np.round(filter *
self.scale_weights[0]).astype(np.int32) self.scale_weights[0]).astype(np.int32)
...@@ -92,21 +115,35 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -92,21 +115,35 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer( output1 = conv2d_forward_refer(
input.astype(np.int32), filter_int, self.groups, input.astype(np.int32), filter_int, self.groups,
conv2d_param).astype(np.float32) conv2d_param).astype(np.float32)
if self.fuse_residual:
input_residual = np.random.randint(
0, 10, self.input_residual_size).astype(self.srctype)
output_tmp = np.round(output1 * (self.scale_out / (
self.scale_in * self.scale_weights[0])) + format_reorder(
input_residual, self.input_residual_size).astype(
np.int32) * (self.scale_out / self.scale_in_eltwise
))
output_tmp2 = np.round(output1 * (
self.scale_out / (self.scale_in * self.scale_weights[0])))
if self.fuse_relu: if self.fuse_relu:
output = np.maximum( output = np.maximum(output_tmp, 0).astype(self.dsttype)
np.round(output1 * (self.scale_out / ( else:
self.scale_in * self.scale_weights[0]))), output = output_tmp.astype(self.dsttype)
0).astype(self.dsttype) else:
if self.fuse_relu:
output = np.maximum(output_tmp2, 0).astype(self.dsttype)
else: else:
output = np.round(output1 * (self.scale_out / ( output = output_tmp2.astype(self.dsttype)
self.scale_in *
self.scale_weights[0]))).astype(self.dsttype)
self.inputs = { self.inputs = {
'Input': 'Input':
OpTest.np_dtype_to_fluid_dtype(input.astype(self.srctype)), OpTest.np_dtype_to_fluid_dtype(input.astype(self.srctype)),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter) 'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
} }
if self.fuse_residual:
self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
input_residual)
self.attrs = { self.attrs = {
'strides': self.stride, 'strides': self.stride,
'paddings': self.pad, 'paddings': self.pad,
...@@ -119,7 +156,9 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -119,7 +156,9 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_in': self.scale_in, 'Scale_in': self.scale_in,
'Scale_out': self.scale_out, 'Scale_out': self.scale_out,
'Scale_weights': self.scale_weights, 'Scale_weights': self.scale_weights,
'fuse_relu': self.fuse_relu 'Scale_in_eltwise': self.scale_in_eltwise,
'fuse_relu': self.fuse_relu,
'fuse_residual_connection': self.fuse_residual
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
...@@ -137,11 +176,14 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -137,11 +176,14 @@ class TestConv2dInt8Op(TestConv2dOp):
def init_test_case(self): def init_test_case(self):
TestConv2dOp.init_test_case(self) TestConv2dOp.init_test_case(self)
self.input_size = [1, 1, 5, 5] # NCHW
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [1, f_c, 3, 3] self.input_residual_size = [1, 2, 3, 3]
self.filter_size = [2, f_c, 3, 3]
self.scale_in = 1.0 self.scale_in = 1.0
self.scale_out = 0.5 self.scale_out = 0.5
self.scale_weights = [10.0] self.scale_weights = [10.0]
self.scale_in_eltwise = 0.6
def init_data_type(self): def init_data_type(self):
self.srctype = np.uint8 self.srctype = np.uint8
...@@ -150,8 +192,11 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -150,8 +192,11 @@ class TestConv2dInt8Op(TestConv2dOp):
def init_fuse_relu(self): def init_fuse_relu(self):
self.fuse_relu = True self.fuse_relu = True
def init_fuse_residual(self):
self.fuse_residual = True
#--------------------test conv2d u8 in and u8 out-------------------- #--------------------test conv2d u8 in and u8 out with residual fuse--------------------
class TestConv2d(TestConv2dInt8Op): class TestConv2d(TestConv2dInt8Op):
...@@ -159,18 +204,21 @@ class TestConv2d(TestConv2dInt8Op): ...@@ -159,18 +204,21 @@ class TestConv2d(TestConv2dInt8Op):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW self.input_size = [2, 3, 5, 5] # NCHW
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
self.scale_in = 1.0 self.scale_in = 1.0
self.scale_out = 0.5 self.scale_out = 0.5
self.scale_weights = [10.0] self.scale_weights = [10.0]
self.scale_in_eltwise = 0.6
class TestWithPad(TestConv2d): class TestWithPad(TestConv2d):
def init_test_case(self): def init_test_case(self):
TestConv2d.init_test_case(self) TestConv2d.init_test_case(self)
self.pad = [1, 1] self.pad = [1, 1]
self.input_residual_size = [2, 6, 5, 5]
class TestWithGroup(TestConv2d): class TestWithGroup(TestConv2d):
...@@ -183,12 +231,14 @@ class TestWithStride(TestConv2dInt8Op): ...@@ -183,12 +231,14 @@ class TestWithStride(TestConv2dInt8Op):
self.pad = [1, 1] self.pad = [1, 1]
self.stride = [2, 2] self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] self.input_size = [2, 3, 6, 6]
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
self.scale_in = 1.0 self.scale_in = 1.0
self.scale_out = 0.8 self.scale_out = 0.8
self.scale_weights = [10.0] self.scale_weights = [10.0]
self.scale_in_eltwise = 0.5
class TestWith1x1(TestConv2dInt8Op): class TestWith1x1(TestConv2dInt8Op):
...@@ -196,12 +246,14 @@ class TestWith1x1(TestConv2dInt8Op): ...@@ -196,12 +246,14 @@ class TestWith1x1(TestConv2dInt8Op):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.input_size = [1, 3, 5, 5] self.input_size = [1, 3, 5, 5]
self.input_residual_size = [1, 6, 5, 5]
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1] self.filter_size = [6, f_c, 1, 1]
self.scale_in = 1.0 self.scale_in = 1.0
self.scale_out = 0.5 self.scale_out = 0.5
self.scale_weights = [12.0] self.scale_weights = [12.0]
self.scale_in_eltwise = 0.5
class TestWithInput1x1Filter1x1(TestConv2dInt8Op): class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
...@@ -209,24 +261,29 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op): ...@@ -209,24 +261,29 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.input_size = [2, 3, 1, 1] self.input_size = [2, 3, 1, 1]
self.input_residual_size = [2, 6, 1, 1]
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1] self.filter_size = [6, f_c, 1, 1]
self.scale_in = 1.0 self.scale_in = 1.0
self.scale_out = 0.5 self.scale_out = 0.5
self.scale_weights = [10.0] self.scale_weights = [10.0]
self.scale_in_eltwise = 0.8
def init_group(self): def init_group(self):
self.groups = 3 self.groups = 3
def init_data_type_with_fusion(self, input_dt, fuse_relu): def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual):
self.srctype = input_dt self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8 self.dsttype = np.uint8 if fuse_relu else np.int8
def init_fuse_relu(self): def init_fuse_relu(self):
self.fuse_relu = fuse_relu self.fuse_relu = fuse_relu
def init_fuse_residual(self):
self.fuse_residual = fuse_residual
def create_test_int8_class(parent): def create_test_int8_class(parent):
...@@ -234,29 +291,68 @@ def create_test_int8_class(parent): ...@@ -234,29 +291,68 @@ def create_test_int8_class(parent):
class TestS8U8Case(parent): class TestS8U8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True) init_data_type_with_fusion(self, np.int8, True, False)
#--------------------test conv2d s8 in and s8 out-------------------- #--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent): class TestS8S8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False) init_data_type_with_fusion(self, np.int8, False, False)
#--------------------test conv2d u8 in and s8 out-------------------- #--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent): class TestU8S8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, False) init_data_type_with_fusion(self, np.uint8, False, False)
#--------------------test conv2d u8 in and u8 out without residual fuse--------------------
class TestU8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, True, False)
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True, True)
#--------------------test conv2d s8 in and s8 out with residual fuse--------------------
class TestS8S8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False, True)
#--------------------test conv2d u8 in and s8 out with residual fuse--------------------
cls_name_s8u8 = "{0}_relu_{1}".format(parent.__name__, "1") class TestU8S8ResCase(parent):
cls_name_s8s8 = "{0}_relu_{1}".format(parent.__name__, "0") def init_data_type(self):
cls_name_u8s8 = "{0}_relu_{1}".format(parent.__name__, "0") init_data_type_with_fusion(self, np.uint8, False, True)
cls_name_s8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"1", "1")
cls_name_s8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"0", "1")
cls_name_u8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"0", "1")
TestS8U8Case.__name__ = cls_name_s8u8 TestS8U8Case.__name__ = cls_name_s8u8
TestS8S8Case.__name__ = cls_name_s8s8 TestS8S8Case.__name__ = cls_name_s8s8
TestU8S8Case.__name__ = cls_name_u8s8 TestU8S8Case.__name__ = cls_name_u8s8
TestU8U8Case.__name__ = cls_name_u8u8
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
TestS8S8ResCase.__name__ = cls_name_s8s8_re_1
TestU8S8ResCase.__name__ = cls_name_u8s8_re_1
globals()[cls_name_s8u8] = TestS8U8Case globals()[cls_name_s8u8] = TestS8U8Case
globals()[cls_name_s8s8] = TestS8S8Case globals()[cls_name_s8s8] = TestS8S8Case
globals()[cls_name_u8s8] = TestU8S8Case globals()[cls_name_u8s8] = TestU8S8Case
globals()[cls_name_u8u8] = TestU8U8Case
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
globals()[cls_name_s8s8_re_1] = TestS8S8ResCase
globals()[cls_name_u8s8_re_1] = TestU8S8ResCase
create_test_int8_class(TestConv2dInt8Op) create_test_int8_class(TestConv2dInt8Op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册