Commit d5167b48 authored by xiaolil1

Modification for the INT8 kernel PR: temporarily remove the requantization op

Parent a042d86b
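For context, the INT8 convolution path touched in this commit folds the input, weight and output quantization scales (Scale_in, Scale_weights, Scale_out, Scale_in_eltwise in the diff) into a single output shift scale applied by the primitive, plus a separate sum scale for the fused residual input. The sketch below is plain C++ with illustrative helper names (MakeOutputShiftScale, MakeSumScale are not functions in the codebase); the exact formula is an assumption based on the usual convention, not a statement of the kernel's implementation.

```cpp
#include <cstddef>
#include <vector>

// Sketch only: how output_shift_scale and sum_scale could be derived from the
// per-tensor scales read by the kernel. One weight-scale entry per output
// channel when is_multi_channel, a single entry otherwise.
std::vector<float> MakeOutputShiftScale(float scale_in,
                                        const std::vector<float>& scale_weights,
                                        float scale_out) {
  std::vector<float> output_shift_scale(scale_weights.size());
  for (std::size_t i = 0; i < scale_weights.size(); ++i) {
    output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]);
  }
  return output_shift_scale;
}

float MakeSumScale(float scale_out, float scale_in_eltwise) {
  // Scale applied to the fused residual (elementwise add) input.
  return scale_out / scale_in_eltwise;
}
```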
......@@ -16,7 +16,6 @@
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include <unordered_map>
#include <map>
namespace paddle {
namespace operators {
......@@ -297,9 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
bool is_INT8 = ctx.HasInput("Scale_in")? true : false;
if(!is_INT8){
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
......@@ -345,7 +341,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const float* filter_data = filter->data<float>();
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
......@@ -373,11 +369,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline;
bool is_INT8 = ctx.HasInput("Scale_in")? true : false;
if(!is_INT8){
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(),
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
(g == 1) ? mkldnn::memory::format::oihw : mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -393,11 +391,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
......@@ -419,7 +419,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
user_weights_md, to_void_cast<float>(filter_data));
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
......@@ -492,20 +492,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else{
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
} else{
auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr;
......@@ -513,65 +501,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool is_multi_channel = (scale_weights->memory_size() > 1) ? true : false;
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != memory::format::format_undef,
"Wrong layout/format set for Bias tensor");
PADDLE_ENFORCE(bias->dims().size() == 1,
"Bias must only have 1 dimension, i.e. X");
}
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
const T* input_data = input->data<T>();
const float* filter_data = filter->data<float>();
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
if (g > 1) {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
}
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key = ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd";
static std::unordered_map<std::string, std::vector<float>> scale_map;
bool scale_reuse = false;
bool scale_reuse = true;
auto scale_in_key = key + "@scale_in";
auto scale_weights_key = key + "@scale_weights";
auto scale_out_key = key + "@scale_out";
......@@ -587,11 +519,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<float> none_scale = {0};
if (GetScaleMap(scale_map, scale_in_key) == none_scale){
scale_reuse = true;
scale_reuse = false;
}
//std::cout<<"scale_reuse = "<<scale_reuse<<std::endl;
if(scale_reuse){
//std::cout<<"load scale!!!!!!!!"<<std::endl;
if(!scale_reuse){
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
scale_in_data = {*(scale_in->data<float>())};
scale_weights_data.resize(count);
......@@ -629,10 +560,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
output_shift_scale = GetScaleMap(scale_map, output_shift_scale_key);
sum_scale = GetScaleMap(scale_map, sum_scale_key);
//printf("pause!!!");
}
std::vector<primitive> pipeline;
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, paddle::framework::ToMKLDNNDataType(input->type()), input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
......@@ -647,15 +578,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
auto bias_tz = paddle::framework::vectorize2int(bias->dims());
auto src_md = platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8,
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
auto dst_dt = fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_dt = fuse_relu?
paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char)))
: paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
if(fuse_residual_conn){
auto residual = ctx.Input<Tensor>("ResidualData");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
......@@ -665,6 +598,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) {
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x);
......@@ -706,39 +641,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the "
"same dimension sizes");
auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
if(residual_param->format() != handler.GetDstFormat()) {
auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims());
auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
if(residual_dt == mkldnn::memory::data_type::u8){
auto residual_param_data = residual_param->data<uint8_t>();
auto user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<uint8_t>(residual_param_data));
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<uint8_t>(output_data), pipeline);
} else{
auto residual_param_data = residual_param->data<int8_t>();
auto user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<int8_t>(residual_param_data));
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<int8_t>(output_data), pipeline);
if(fuse_relu)
need_s8_to_u8 = true;
}
} else {
PADDLE_ENFORCE_EQ(residual_param->format(), handler.GetDstFormat(),
"Conv input dimension and filter dimension should be the same.");
output->ShareDataWith(*residual_param);
if(residual_dt == mkldnn::memory::data_type::u8){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
......@@ -751,7 +655,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if(fuse_relu)
need_s8_to_u8 = true;
}
}
} else {
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
......@@ -776,7 +679,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p;
int mask_reorder = is_multi_channel? 1<<0 : 1;
if(scale_reuse){
if(!scale_reuse){
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
scale_bias_data.resize(count);
#pragma omp parallel for if (count > 1)
......@@ -918,7 +821,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; //Fix propagation bug for FP32 inference.
auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights,
......@@ -972,7 +875,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; //Fix propagation bug for FP32 inference.
auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights,
......
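A note on the dst_dt selection in the kernel above: when ReLU is fused the destination is stored as u8 (the activation output is non-negative), otherwise as s8, which changes the saturation range used when the INT32 accumulator is scaled back down. A hedged, purely illustrative sketch of that saturation (SaturateToDst is a hypothetical helper, not part of the codebase):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative only: saturate a scaled accumulator to the range implied by
// dst_dt above (u8 when fuse_relu, s8 otherwise).
int32_t SaturateToDst(float scaled_acc, bool fuse_relu) {
  const float lo = fuse_relu ? 0.0f : -128.0f;
  const float hi = fuse_relu ? 255.0f : 127.0f;
  return static_cast<int32_t>(std::lround(std::max(lo, std::min(hi, scaled_acc))));
}
```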
......@@ -54,7 +54,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
mkldnn::memory::format src_fmt = input->format();
mkldnn::primitive_attr attri;
int mask = 0;
......@@ -101,12 +101,10 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::Execut
}
void DeQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from INT8 to FP32
)DOC");
AddInput("Input","input data");
AddInput("Scale","scale data");
AddOutput("Output","output data");
AddComment(R"DOC(This op will quantize data from INT8 to FP32)DOC");
}
} // namespace operators
......
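The dequantize kernel above performs this conversion with an MKL-DNN reorder; numerically it is just a division by the quantization scale. A standalone sketch of that arithmetic, assuming the common convention that the scale maps FP32 to INT8 by multiplication (Dequantize is an illustrative name, not the kernel):

```cpp
#include <cstdint>
#include <vector>

// Sketch of INT8 -> FP32 dequantization: real_value = quantized_value / scale.
std::vector<float> Dequantize(const std::vector<int8_t>& q, float scale) {
  std::vector<float> out(q.size());
  for (std::size_t i = 0; i < q.size(); ++i) {
    out[i] = static_cast<float>(q[i]) / scale;
  }
  return out;
}
```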
......@@ -95,18 +95,17 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::Executio
void QuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will quantize data from FP32 to INT8
)DOC");
AddInput("Input","input data");
AddInput("Scale","scale data");
AddOutput("Output","output data");
AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
//TODO Support FP32->S8 quantization.
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
......
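The quantize op registered above maps FP32 activations to unsigned INT8 (the TODO notes that FP32->S8 is not supported yet); the underlying arithmetic is scale, round and saturate. A standalone sketch under that assumption, not the kernel itself (QuantizeToU8 is a hypothetical helper):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of FP32 -> U8 quantization: q = saturate(round(x * scale)).
std::vector<uint8_t> QuantizeToU8(const std::vector<float>& x, float scale) {
  std::vector<uint8_t> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    float v = std::round(x[i] * scale);
    out[i] = static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, v)));
  }
  return out;
}
```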
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/requantize_op.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
//std::cout<<"this is requant op!!!!!"<<std::endl;
auto* input = ctx.Input<Tensor>("Input");
//auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::data_type dst_dt = mkldnn::memory::data_type::u8;//paddle::framework::ToMKLDNNDataType(output->type());
mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
mkldnn::memory::format dst_fmt = memory::format::nhwc;//output->format();
const T* input_data = input->data<T>();
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
//T scale_data = *(scale->data<T>());
std::vector<float> scale_data = {0.9999999}; //{*(scale->data<float>())};
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask,scale_data);// scale_data);
//attri.set_int_output_round_mode(round_nearest); //FIX ME
auto src_md = platform::MKLDNNMemDesc(
{src_tz}, src_dt, src_fmt); //FIX ME WITH S8
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine);
auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, dst_dt, dst_fmt);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri));
int is_sum = ctx.Attr<int>("is_sum");
if(is_sum){
//std::cout<<"is_sum == true"<<std::endl;
memcpy(output_data, input_data, sizeof(uint8_t) * input->numel());
} else{
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait();
}
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
//std::cout<<"requant op end!!!!!"<<std::endl;
}
};
framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
}
void ReQuantOpMaker::Make() {
AddInput("Input","input");
AddInput("Scale","scale...");
AddOutput("Output","output");
AddComment(R"DOC(
This op will requantize data from INT8 to INT8
)DOC");
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, ops::ReQuantOpKernel<int8_t>);
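The requantize kernel above (removed temporarily by this commit) currently hard-codes scale_data to 0.9999999 as a placeholder and performs the conversion with a reorder carrying that output scale. Numerically, requantization rescales one INT8 representation into another by the ratio of the two scales. A standalone, illustrative sketch of that arithmetic (Requantize is a hypothetical helper, not the kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch of INT8 -> INT8 requantization:
// q_out = saturate(round(q_in * scale_out / scale_in)), here with a u8 output
// to match the kernel's hard-coded dst data type.
std::vector<uint8_t> Requantize(const std::vector<int8_t>& q_in,
                                float scale_in, float scale_out) {
  std::vector<uint8_t> q_out(q_in.size());
  const float ratio = scale_out / scale_in;
  for (std::size_t i = 0; i < q_in.size(); ++i) {
    float v = std::round(static_cast<float>(q_in[i]) * ratio);
    q_out[i] = static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, v)));
  }
  return q_out;
}
```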
......@@ -153,7 +153,6 @@ class MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
//mem_p = nullptr;
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
......@@ -234,10 +233,9 @@ class MKLDNNHandler {
std::shared_ptr<mkldnn::primitive> reorder_p;
if (mpd != user_mpd) {
target_memory_p = std::make_shared<mkldnn::memory>(mpd);
std::shared_ptr<mkldnn::reorder> reorder_p;// =
//std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
std::shared_ptr<mkldnn::reorder> reorder_p;
if(is_INT8){
mkldnn::primitive_attr attri;
mkldnn::primitive_attr attri; //attribute for int8 weights and bias data reorder.
attri.set_output_scales(mask, scale_data);
auto reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
......
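The primitive_attr added in the handler above lets the weights/bias reorder apply the quantization scales while converting to the INT8 target type; with a non-zero mask the scale is applied per output channel. A hedged plain-C++ sketch of the effect of such a scaled reorder, not the MKL-DNN reorder itself (ReorderWeightsToS8 is an illustrative name):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: what an output-scaled weights reorder computes. With a per-channel
// mask the scale vector has one entry per output channel; with mask == 0 it
// has a single entry applied to every element.
std::vector<int8_t> ReorderWeightsToS8(const std::vector<float>& w,
                                       const std::vector<float>& scale,
                                       std::size_t per_channel_size) {
  std::vector<int8_t> out(w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    float s = (scale.size() == 1) ? scale[0] : scale[i / per_channel_size];
    float v = std::round(w[i] * s);
    out[i] = static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, v)));
  }
  return out;
}
```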