Commit a042d86b authored by xiaolil1

adjust structure for PR, peel int8 from fp32

Parent fc9e1347
The changed region is shown below as it reads after this commit:

@@ -296,384 +296,518 @@ template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    bool is_INT8 = ctx.HasInput("Scale_in") ? true : false;

    if (!is_INT8) {
      PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                     "It must use CPUPlace.");

      const bool is_test = ctx.Attr<bool>("is_test");

      auto& dev_ctx =
          ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
      const auto& mkldnn_engine = dev_ctx.GetEngine();

      auto* input = ctx.Input<Tensor>("Input");
      auto* filter = ctx.Input<Tensor>("Filter");
      auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
      auto* output = ctx.Output<Tensor>("Output");

      PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
                         input->format() != memory::format::format_undef,
                     "Wrong layout/format set for Input tensor");
      PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                         filter->format() != memory::format::format_undef,
                     "Wrong layout/format set for Filter tensor");
      PADDLE_ENFORCE(input->dims().size() == 4,
                     "Input must be with 4 dimensions, i.e. NCHW");
      PADDLE_ENFORCE(filter->dims().size() == 4,
                     "Filter must be with 4 dimensions, i.e. OIHW");
      if (bias) {
        PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                           bias->format() != memory::format::format_undef,
                       "Wrong layout/format set for Bias tensor");
        PADDLE_ENFORCE(bias->dims().size() == 1,
                       "Bias must only have 1 dimension, i.e. X");
      }

      std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
      std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
      std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
      bool fuse_relu = ctx.Attr<bool>("fuse_relu");
      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
      int groups = ctx.Attr<int>("groups");

      // TODO(tpatejko): add support for dilation
      PADDLE_ENFORCE(
          dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
          "dilation in convolution is not implemented yet");

      const T* input_data = input->data<T>();
      const T* filter_data = filter->data<T>();

      std::vector<int> src_tz =
          paddle::framework::vectorize2int(input->dims());
      std::vector<int> weights_tz =
          paddle::framework::vectorize2int(filter->dims());
      int g = std::max(groups, 1);
      if (g > 1) {
        int o = weights_tz[0];
        int i = weights_tz[1];
        int h = weights_tz[2];
        int w = weights_tz[3];
        weights_tz.resize(5);
        weights_tz[0] = g;
        weights_tz[1] = o / g;
        weights_tz[2] = i;
        weights_tz[3] = h;
        weights_tz[4] = w;
      }
      std::vector<int> dst_tz =
          paddle::framework::vectorize2int(output->dims());

      // Get unique name for storing MKLDNN primitives
      const std::string key = ConvMKLDNNHandler::GetHash(
          src_tz, weights_tz, strides, paddings, dilations, groups,
          ctx.op().Output("Output"));
      const std::string key_conv_pd = key + "@conv_pd";

      std::vector<primitive> pipeline;

      auto user_src_md = platform::MKLDNNMemDesc(
          {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
      auto user_weights_md = platform::MKLDNNMemDesc(
          {weights_tz}, platform::MKLDNNGetDataType<T>(),
          (g == 1) ? filter->format() : mkldnn::memory::format::goihw);

      /* create memory descriptor for convolution without specified format
       * ('any') which lets a primitive (convolution in this case) choose
       * the memory format preferred for best performance
       */
      std::string data_format = ctx.Attr<std::string>("data_format");
      auto chosen_memory_format =
          platform::data_format_to_memory_format(data_format);

      auto src_md = platform::MKLDNNMemDesc(
          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
      auto weights_md = platform::MKLDNNMemDesc(
          weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
      std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                                 // Currently used whenever bias is != nullptr.
      auto dst_md = platform::MKLDNNMemDesc(
          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

      // create a conv primitive descriptor and save it for usage in backward
      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
      if (bias) {
        bias_tz = paddle::framework::vectorize2int(bias->dims());
        auto bias_md = platform::MKLDNNMemDesc(
            bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
                                       strides, paddings, mkldnn_engine,
                                       fuse_relu, fuse_residual_conn);
      } else {
        conv_pd =
            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
                                 mkldnn_engine, fuse_relu, fuse_residual_conn);
      }
      // Save conv_pd/src_memory/weights_memory for backward pass
      dev_ctx.SetBlob(key_conv_pd, conv_pd);

      ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);

      // create mkldnn memory from input tensors (data/weights)
      auto user_src_memory_p =
          handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
      auto user_weights_memory_p = handler.AcquireWeightsMemory(
          user_weights_md, to_void_cast<T>(filter_data));

      // create reorder primitive if the input format is not the preferred one
      auto src_memory_p =
          handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
      auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
          user_weights_memory_p, pipeline, is_test);

      std::shared_ptr<mkldnn::memory> dst_memory_p;

      if (fuse_residual_conn) {
        auto residual_param = ctx.Input<Tensor>("ResidualData");
        auto residual_param_data = residual_param->data<T>();

        PADDLE_ENFORCE(
            residual_param_data != nullptr,
            "Provide data if you want MKLDNN conv+elementwise_add fusion");
        PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                          "Output and elementwise parameter need to have the "
                          "same dimension sizes");

        if (residual_param->format() != handler.GetDstFormat()) {
          auto output_data = output->mutable_data<T>(
              ctx.GetPlace(), handler.GetDstMemorySize());
          auto residual_data_tz =
              paddle::framework::vectorize2int(residual_param->dims());
          auto residual_data_type =
              paddle::framework::ToMKLDNNDataType(residual_param->type());

          auto user_residual_md = platform::MKLDNNMemDesc(
              residual_data_tz, residual_data_type, residual_param->format());
          auto user_residual_memory_p = handler.AcquireResidualDataMemory(
              user_residual_md, to_void_cast<T>(residual_param_data));

          dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
              user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
        } else {
          output->ShareDataWith(*residual_param);
          auto output_data = output->mutable_data<T>(ctx.GetPlace());
          dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
              to_void_cast<T>(output_data));
        }
      } else {
        auto output_data = output->mutable_data<T>(
            ctx.GetPlace(), handler.GetDstMemorySize());
        dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
            to_void_cast<T>(output_data));
      }
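      // Editor's note: with fuse_residual_connection the convolution adds its
      // result onto the residual tensor (elementwise_add fused as a sum
      // post-op), so the residual data either becomes the destination buffer
      // directly (ShareDataWith) or is first reordered into the primitive's
      // preferred destination format.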
      // create convolution op primitive
      std::shared_ptr<mkldnn::convolution_forward> conv_p;
      if (bias) {
        const T* bias_data = bias->data<T>();
        auto user_bias_md = platform::MKLDNNMemDesc(
            {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
        auto user_bias_memory_p =
            handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));

        auto bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(
            user_bias_memory_p, pipeline);
        conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                            bias_memory_p, dst_memory_p);
      } else {
        conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                            dst_memory_p);
      }

      // push primitive to stream and wait until it's executed
      pipeline.push_back(*conv_p);
      stream(stream::kind::eager).submit(pipeline).wait();

      output->set_layout(DataLayout::kMKLDNN);
      output->set_format(GetMKLDNNFormat(*dst_memory_p));
    } else {
      PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                     "It must use CPUPlace.");
      const bool is_test = ctx.Attr<bool>("is_test");

      auto& dev_ctx =
          ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
      const auto& mkldnn_engine = dev_ctx.GetEngine();

      auto* input = ctx.Input<Tensor>("Input");
      auto* filter = ctx.Input<Tensor>("Filter");
      auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
      auto* output = ctx.Output<Tensor>("Output");

      auto* scale_in =
          ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
      auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")
                                   ? ctx.Input<Tensor>("Scale_in_eltwise")
                                   : nullptr;
      auto* scale_weights = ctx.HasInput("Scale_weights")
                                ? ctx.Input<Tensor>("Scale_weights")
                                : nullptr;
      auto* scale_out =
          ctx.HasInput("Scale_out") ? ctx.Input<Tensor>("Scale_out") : nullptr;

      bool is_multi_channel = (scale_weights->memory_size() > 1) ? true : false;

      PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
                         input->format() != memory::format::format_undef,
                     "Wrong layout/format set for Input tensor");
      PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                         filter->format() != memory::format::format_undef,
                     "Wrong layout/format set for Filter tensor");
      PADDLE_ENFORCE(input->dims().size() == 4,
                     "Input must be with 4 dimensions, i.e. NCHW");
      PADDLE_ENFORCE(filter->dims().size() == 4,
                     "Filter must be with 4 dimensions, i.e. OIHW");
      if (bias) {
        PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                           bias->format() != memory::format::format_undef,
                       "Wrong layout/format set for Bias tensor");
        PADDLE_ENFORCE(bias->dims().size() == 1,
                       "Bias must only have 1 dimension, i.e. X");
      }

      std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
      std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
      std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
      bool fuse_relu = ctx.Attr<bool>("fuse_relu");
      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
      int groups = ctx.Attr<int>("groups");

      // TODO(tpatejko): add support for dilation
      PADDLE_ENFORCE(
          dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
          "dilation in convolution is not implemented yet");

      const T* input_data = input->data<T>();
      const float* filter_data = filter->data<float>();

      std::vector<int> src_tz =
          paddle::framework::vectorize2int(input->dims());
      std::vector<int> weights_tz =
          paddle::framework::vectorize2int(filter->dims());
      int g = std::max(groups, 1);
      if (g > 1) {
        int o = weights_tz[0];
        int i = weights_tz[1];
        int h = weights_tz[2];
        int w = weights_tz[3];
        weights_tz.resize(5);
        weights_tz[0] = g;
        weights_tz[1] = o / g;
        weights_tz[2] = i;
        weights_tz[3] = h;
        weights_tz[4] = w;
      }
      std::vector<int> dst_tz =
          paddle::framework::vectorize2int(output->dims());

      // Get unique name for storing MKLDNN primitives
      const std::string key = ConvMKLDNNHandler::GetHash(
          src_tz, weights_tz, strides, paddings, dilations, groups,
          ctx.op().Output("Output"));
      const std::string key_conv_pd = key + "@conv_pd";

      static std::unordered_map<std::string, std::vector<float>> scale_map;

      bool scale_reuse = false;
      auto scale_in_key = key + "@scale_in";
      auto scale_weights_key = key + "@scale_weights";
      auto scale_out_key = key + "@scale_out";
      auto output_shift_scale_key = key + "@output_shift_scale";
      auto sum_scale_key = key + "@sum_scale";
      auto scale_in_eltwise_key = key + "@scale_in_eltwise";
      std::vector<float> scale_in_data;
      std::vector<float> scale_out_data;
      std::vector<float> scale_weights_data;
      std::vector<float> scale_in_eltwise_data;
      std::vector<float> output_shift_scale;
      std::vector<float> sum_scale = {1.0f};
      std::vector<float> none_scale = {0};

      if (GetScaleMap(scale_map, scale_in_key) == none_scale) {
        scale_reuse = true;
      }

      if (scale_reuse) {
        int count = is_multi_channel
                        ? (g > 1 ? weights_tz[1] * weights_tz[0] : weights_tz[0])
                        : 1;
        scale_in_data = {*(scale_in->data<float>())};
        scale_weights_data.resize(count);
#pragma omp parallel for if (count > 1)
        for (int i = 0; i < count; i++) {
          scale_weights_data[i] = *(scale_weights->data<float>() + i);
        }
        scale_out_data = {*(scale_out->data<float>())};
        output_shift_scale.resize(count);
#pragma omp parallel for if (count > 1)
        for (int i = 0; i < count; i++) {
          if (scale_weights_data[i] == 0.0)
            output_shift_scale[i] = scale_out_data[0];
          else
            output_shift_scale[i] =
                scale_out_data[0] / (scale_in_data[0] * scale_weights_data[i]);
        }
        if (fuse_residual_conn) {
          scale_in_eltwise_data = {*(scale_in_eltwise->data<float>())};
          sum_scale[0] = scale_out_data[0] / scale_in_eltwise_data[0];
          SetScaleMap(scale_map, scale_in_eltwise_key, scale_in_eltwise_data);
        }

        // scale reuse
        SetScaleMap(scale_map, scale_in_key, scale_in_data);
        SetScaleMap(scale_map, scale_weights_key, scale_weights_data);
        SetScaleMap(scale_map, scale_out_key, scale_out_data);
        SetScaleMap(scale_map, output_shift_scale_key, output_shift_scale);
        SetScaleMap(scale_map, sum_scale_key, sum_scale);
      } else {
        scale_in_data = GetScaleMap(scale_map, scale_in_key);
        scale_out_data = GetScaleMap(scale_map, scale_out_key);
        scale_weights_data = GetScaleMap(scale_map, scale_weights_key);
        if (fuse_residual_conn) {
          scale_in_eltwise_data = GetScaleMap(scale_map, scale_in_eltwise_key);
        }
        output_shift_scale = GetScaleMap(scale_map, output_shift_scale_key);
        sum_scale = GetScaleMap(scale_map, sum_scale_key);
      }
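      // Editor's note on the arithmetic above: with u8 activations and s8
      // weights the convolution accumulates in int32 with a combined scale of
      // scale_in * scale_weights[i] for output channel i, so requantizing the
      // accumulator into the output range uses
      //   output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]);
      // when the residual (sum) fusion is active, the residual input is
      // aligned to the output range with sum_scale = scale_out / scale_in_eltwise.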
      std::vector<primitive> pipeline;

      auto user_src_md = platform::MKLDNNMemDesc(
          {src_tz}, paddle::framework::ToMKLDNNDataType(input->type()),
          input->format());
      auto user_weights_md = platform::MKLDNNMemDesc(
          {weights_tz}, platform::MKLDNNGetDataType<float>(),
          (g == 1) ? mkldnn::memory::format::oihw
                   : mkldnn::memory::format::goihw);

      /* create memory descriptor for convolution without specified format
       * ('any') which lets a primitive (convolution in this case) choose
       * the memory format preferred for best performance
       */
      std::string data_format = ctx.Attr<std::string>("data_format");
      auto chosen_memory_format =
          platform::data_format_to_memory_format(data_format);

      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
      auto bias_tz = paddle::framework::vectorize2int(bias->dims());

      auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::u8,
                                            chosen_memory_format);
      auto weights_md = platform::MKLDNNMemDesc(
          weights_tz, memory::data_type::s8,
          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);

      auto dst_dt = fuse_relu
                        ? paddle::framework::ToMKLDNNDataType(
                              std::type_index(typeid(unsigned char)))
                        : paddle::framework::ToMKLDNNDataType(
                              std::type_index(typeid(signed char)));
      if (fuse_residual_conn) {
        auto residual = ctx.Input<Tensor>("ResidualData");
        auto residual_dt =
            paddle::framework::ToMKLDNNDataType(residual->type());
        if (dst_dt != residual_dt) dst_dt = residual_dt;
      }
      auto dst_md =
          platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
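      // Editor's note: the destination type follows from the fusions chosen
      // above. A fused ReLU makes the output non-negative, so u8 is used,
      // otherwise s8; and a fused residual input forces the destination to
      // adopt the residual tensor's data type so the sum can run in place.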
      // create a conv primitive descriptor and save it for usage in backward
      if (bias) {
        auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
                                               memory::format::x);
        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
                                       strides, paddings, mkldnn_engine,
                                       fuse_relu, fuse_residual_conn,
                                       output_shift_scale, sum_scale[0],
                                       is_test);
      } else {
        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
                                       paddings, mkldnn_engine, fuse_relu,
                                       fuse_residual_conn, output_shift_scale,
                                       sum_scale[0], is_test);
      }
      // Save conv_pd/src_memory/weights_memory for backward pass
      dev_ctx.SetBlob(key_conv_pd, conv_pd);

      ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);

      // create mkldnn memory from input tensors (data/weights)
      auto user_src_memory_p =
          handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
      auto user_weights_memory_p = handler.AcquireWeightsMemory(
          user_weights_md, to_void_cast<float>(filter_data));

      // create reorder primitive if the input format is not the preferred one
      auto src_memory_p =
          handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);

      std::shared_ptr<mkldnn::memory> weights_memory_p;
      int mask_reorder =
          is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
      weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
          user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data,
          mask_reorder);
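      // Editor's note: the reorder mask is a bit-field over the weight tensor
      // dimensions that carry independent scales. With per-channel scales,
      // bit 0 selects the output-channel dimension of OIHW weights; grouped
      // weights are GOIHW, so bits 0 and 1 cover the group and output-channel
      // dimensions. A single common scale uses mask 0.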
      std::shared_ptr<mkldnn::memory> dst_memory_p;
      bool need_s8_to_u8 = false;
      if (fuse_residual_conn) {
        auto residual_param = ctx.Input<Tensor>("ResidualData");
        PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                          "Output and elementwise parameter need to have the "
                          "same dimension sizes");
        auto residual_dt =
            paddle::framework::ToMKLDNNDataType(residual_param->type());
        if (residual_param->format() != handler.GetDstFormat()) {
          auto residual_data_tz =
              paddle::framework::vectorize2int(residual_param->dims());
          auto residual_data_type =
              paddle::framework::ToMKLDNNDataType(residual_param->type());
          auto user_residual_md = platform::MKLDNNMemDesc(
              residual_data_tz, residual_data_type, residual_param->format());
          if (residual_dt == mkldnn::memory::data_type::u8) {
            auto residual_param_data = residual_param->data<uint8_t>();
            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
                user_residual_md, to_void_cast<uint8_t>(residual_param_data));
            PADDLE_ENFORCE(
                residual_param_data != nullptr,
                "Provide data if you want MKLDNN conv+elementwise_add fusion");
            uint8_t* output_data =
                output->mutable_data<uint8_t>(ctx.GetPlace());
            dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
                user_residual_memory_p, to_void_cast<uint8_t>(output_data),
                pipeline);
          } else {
            auto residual_param_data = residual_param->data<int8_t>();
            auto user_residual_memory_p = handler.AcquireResidualDataMemory(
                user_residual_md, to_void_cast<int8_t>(residual_param_data));
            PADDLE_ENFORCE(
                residual_param_data != nullptr,
                "Provide data if you want MKLDNN conv+elementwise_add fusion");
            int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
            dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
                user_residual_memory_p, to_void_cast<int8_t>(output_data),
                pipeline);
            if (fuse_relu) need_s8_to_u8 = true;
          }
        } else {
          output->ShareDataWith(*residual_param);
          if (residual_dt == mkldnn::memory::data_type::u8) {
            uint8_t* output_data =
                output->mutable_data<uint8_t>(ctx.GetPlace());
            dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
                to_void_cast<uint8_t>(output_data));
          } else {
            int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
            dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
                to_void_cast<int8_t>(output_data));
            if (fuse_relu) need_s8_to_u8 = true;
          }
        }
      } else {
        if (fuse_relu) {
          uint8_t* output_data = output->mutable_data<uint8_t>(
              ctx.GetPlace(), handler.GetDstMemorySize());
          dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
              to_void_cast<uint8_t>(output_data));
        } else {
          int8_t* output_data = output->mutable_data<int8_t>(
              ctx.GetPlace(), handler.GetDstMemorySize());
          dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
              to_void_cast<int8_t>(output_data));
        }
      }

      // create convolution op primitive
      std::shared_ptr<mkldnn::convolution_forward> conv_p;
      std::vector<float> scale_bias_data;
      auto scale_bias_key = key + "@scale_bias";
      if (bias) {
        const float* bias_data = bias->data<float>();
        auto user_bias_md = platform::MKLDNNMemDesc(
            {bias_tz}, platform::MKLDNNGetDataType<float>(),
            memory::format::x);
        auto user_bias_memory_p = handler.AcquireBiasMemory(
            user_bias_md, to_void_cast<float>(bias_data));
        std::shared_ptr<mkldnn::memory> bias_memory_p;
        int mask_reorder = is_multi_channel ? 1 << 0 : 1;
        if (scale_reuse) {
          int count = is_multi_channel
                          ? (g > 1 ? weights_tz[1] * weights_tz[0]
                                   : weights_tz[0])
                          : 1;
          scale_bias_data.resize(count);
#pragma omp parallel for if (count > 1)
          for (int i = 0; i < count; i++) {
            scale_bias_data[i] = scale_in_data[0] * scale_weights_data[i];
          }
          SetScaleMap(scale_map, scale_bias_key, scale_bias_data);
        } else {
          scale_bias_data = GetScaleMap(scale_map, scale_bias_key);
        }
        bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(
            user_bias_memory_p, pipeline, is_test, is_INT8, scale_bias_data,
            mask_reorder);
        conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                            bias_memory_p, dst_memory_p);
      } else {
        conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                            dst_memory_p);
      }

      // push primitive to stream and wait until it's executed
      pipeline.push_back(*conv_p);
      stream(stream::kind::eager).submit(pipeline).wait();

      if (need_s8_to_u8) {
        output->mutable_data<uint8_t>(ctx.GetPlace());
      }
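      // Editor's note: when ReLU is fused but the destination had to stay s8
      // to match an s8 residual tensor, the results are non-negative after
      // execution, so the mutable_data<uint8_t> call above re-types the
      // output tensor as u8 for downstream operators.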
      output->set_layout(DataLayout::kMKLDNN);
      output->set_format(GetMKLDNNFormat(*dst_memory_p));
    }
  }

 private:

@@ -780,7 +914,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                      const memory::desc& dst, const std::vector<int>& strides,
                      const std::vector<int>& paddings,
                      const mkldnn::engine& engine, const bool fuse_relu,
                      const bool fuse_residual_conn,
                      bool is_test = false) const {
    memory::dims stride_dims = {strides[0], strides[1]};
    memory::dims padding_dims = {paddings[0], paddings[1]};

@@ -834,7 +968,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                      const std::vector<int>& strides,
                      const std::vector<int>& paddings,
                      const mkldnn::engine& engine, const bool fuse_relu,
                      const bool fuse_residual_conn,
                      bool is_test = false) const {
    memory::dims stride_dims = {strides[0], strides[1]};
    memory::dims padding_dims = {paddings[0], paddings[1]};
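The hunks above call SetScaleMap and GetScaleMap, whose definitions fall
outside the displayed region. A minimal sketch of what such helpers could
look like follows, assuming GetScaleMap returns the {0} sentinel that the
kernel compares against none_scale; the commit's actual definitions may
differ.

#include <string>
#include <unordered_map>
#include <vector>

// Sketch only: cache quantization scales by name; {0} means "no entry yet".
static void SetScaleMap(
    std::unordered_map<std::string, std::vector<float>>& scale_map,
    const std::string& name, const std::vector<float>& scale_data) {
  auto it = scale_map.find(name);
  if (it == scale_map.end()) {
    scale_map.insert({name, scale_data});  // first write creates the entry
  } else {
    it->second = scale_data;  // later writes overwrite it
  }
}

static std::vector<float> GetScaleMap(
    const std::unordered_map<std::string, std::vector<float>>& scale_map,
    const std::string& name) {
  auto it = scale_map.find(name);
  if (it != scale_map.end()) return it->second;
  return {0};  // sentinel: scales not cached yet
}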