提交 3e5c1915 编写于 作者: X xiaolil1

add weights reorder

上级 1f3300ba
...@@ -131,17 +131,24 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -131,17 +131,24 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) { bool is_persistent = false,
bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd, return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p", user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent); pipeline, is_persistent,
is_INT8, scale_data, mask);
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p, const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline,
bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc(); auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc(); auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p, return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
...@@ -283,6 +290,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -283,6 +290,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr; auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr; auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr;
auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input<Tensor>("Scale_out") : nullptr; auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input<Tensor>("Scale_out") : nullptr;
bool is_multi_channel = (is_INT8 && scale_weights->memory_size() > 1) ? true : false;
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != memory::format::format_undef, input->format() != memory::format::format_undef,
...@@ -338,12 +346,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -338,12 +346,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<T> output_shift_scale; std::vector<T> output_shift_scale;
T sum_scale = 1.0f; T sum_scale = 1.0f;
if(is_INT8){ if(is_INT8){
int count = g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
T scale_in_data = *(scale_in->data<T>()); T scale_in_data = *(scale_in->data<T>());
T scale_in_eltwise_data = *(scale_in_eltwise->data<T>()); T scale_in_eltwise_data = *(scale_in_eltwise->data<T>());
std::vector<T> scale_weights_data(count); std::vector<T> scale_weights_data(count);
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_weights_data[i] =*(scale_weights->data<T>()); scale_weights_data[i] =*(scale_weights->data<T>() + i);
} }
T scale_out_data = *(scale_out->data<T>()); T scale_out_data = *(scale_out->data<T>());
...@@ -436,6 +444,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -436,6 +444,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test); user_weights_memory_p, pipeline, is_test);
if(is_INT8){
int mask_reorder = is_multi_channel? 0 : ((g!= 1) ? (1<<1)+(1<<0) : 1<<0);
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_weights_data(count);
for(int i=0; i<count; i++){
scale_weights_data[i] = *(scale_weights->data<T>() + i);
}
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
}
auto dst_memory_p = auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
...@@ -447,9 +465,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -447,9 +465,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
if(is_INT8){
int mask_reorder = is_multi_channel? 0 : 1<<0;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_bias_data(count);
for(int i=0; i<count; i++){
scale_bias_data[i] = (*scale_in->data<T>()) * (*(scale_weights->data<T>() + i));
}
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder);
}
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); bias_memory_p, dst_memory_p);
} else { } else {
...@@ -470,7 +497,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -470,7 +497,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<T> output_shift_scale, T sum_scale) const { const std::vector<T> output_shift_scale, T sum_scale) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
int mask = 0; int mask = output_shift_scale.size() > 1 ? 1<<1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale); conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_eltwise) { if (fuse_eltwise) {
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册