提交 aac5ee72 编写于 作者: X xiaolil1

remove md recreate for int8

上级 ea060d51
...@@ -401,11 +401,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -401,11 +401,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
scale_in_data = {*(scale_in->data<float>())}; scale_in_data = {*(scale_in->data<float>())};
scale_weights_data.resize(count); scale_weights_data.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_weights_data[i] =*(scale_weights->data<float>() + i); scale_weights_data[i] =*(scale_weights->data<float>() + i);
} }
scale_out_data = {*(scale_out->data<float>())}; scale_out_data = {*(scale_out->data<float>())};
output_shift_scale.resize(count); output_shift_scale.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
if(scale_weights_data[i] == 0.0) if(scale_weights_data[i] == 0.0)
output_shift_scale[i] = scale_out_data[0]; output_shift_scale[i] = scale_out_data[0];
...@@ -454,61 +456,56 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -454,61 +456,56 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
auto src_md = platform::MKLDNNMemDesc( std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format); auto bias_tz = paddle::framework::vectorize2int(bias->dims());
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<float>(),
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
std::vector<int> bias_tz;
if(is_INT8){ if(is_INT8){
src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format); src_tz, memory::data_type::u8, chosen_memory_format);
weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, weights_tz, memory::data_type::s8,
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw); (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char))); auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
if(fuse_residual_conn){ if(fuse_residual_conn){
auto residual = ctx.Input<Tensor>("ResidualData"); auto residual = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type()); auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
if(dst_dt != residual_dt) if(dst_dt != residual_dt)
dst_dt = residual_dt; dst_dt = residual_dt;
} }
dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
}
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; if (bias) {
if (bias) { auto bias_md = platform::MKLDNNMemDesc(
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x);
if(is_INT8){
bias_md = platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x); bias_tz, memory::data_type::s32, memory::format::x);
}
if(is_INT8){
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale[0]); output_shift_scale, sum_scale[0]);
} else{ } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
}
} else {
if(is_INT8){
conv_pd = conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, mkldnn_engine, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale[0]); output_shift_scale, sum_scale[0]);
} else{ }
conv_pd = } else{
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, auto src_md = platform::MKLDNNMemDesc(
mkldnn_engine, fuse_relu, fuse_residual_conn); src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<float>(),
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
if (bias) {
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn);
} }
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
...@@ -634,6 +631,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -634,6 +631,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if(scale_reuse){ if(scale_reuse){
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
scale_bias_data.resize(count); scale_bias_data.resize(count);
#pragma omp parallel for if (count > 1)
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_bias_data[i] = scale_in_data[0] * scale_weights_data[i]; scale_bias_data[i] = scale_in_data[0] * scale_weights_data[i];
} }
...@@ -722,7 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -722,7 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// the scale parameter. It is assumed that when fuse_residual_conn is true, the // the scale parameter. It is assumed that when fuse_residual_conn is true, the
// Output tensor contains the data coming from residual connection. The // Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out. // result of this post_op is: Output = scale * Output + Conv_Out.
conv_attr.set_output_scales(0, {1.0f});
if (fuse_residual_conn) { if (fuse_residual_conn) {
post_operations.append_sum(1.0f); post_operations.append_sum(1.0f);
} }
...@@ -774,7 +772,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -774,7 +772,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc( auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
...@@ -824,7 +822,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -824,7 +822,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc( auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims, bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册