Commit 1f9da193 authored by chuanqiw

Merge branch 'prv-calibration' of ssh://git-ccr-1.devtools.intel.com:29418/intelpaddle-shanghai into prv-calibration
@@ -300,7 +300,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  const bool is_test = ctx.Attr<bool>("is_test");
  auto& dev_ctx =
@@ -335,8 +334,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
  bool fuse_relu = ctx.Attr<bool>("fuse_relu");
  bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+ bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
  int groups = ctx.Attr<int>("groups");
+ if (fuse_residual_conn) {
+   PADDLE_ENFORCE(force_fp32_output != true,
+                  "residual fusion does not support force output with fp32");
+ }
  // TODO(tpatejko): add support for dilation
  PADDLE_ENFORCE(
      dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
@@ -378,20 +383,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  std::shared_ptr<mkldnn::convolution_forward> conv_p;
  std::shared_ptr<mkldnn::memory> src_memory_p;
+ std::shared_ptr<mkldnn::memory> user_src_memory_p;
  std::shared_ptr<mkldnn::memory> dst_memory_p;
  std::vector<primitive> pipeline;
  auto prim_key = key + "@conv_p";
  auto dst_key = key + "@dst_mem_p";
  auto src_key = key + "@src_mem_p";
+ auto user_src_key = key + "@user_src_mem_p";
+ auto src_reorder_key = key + "@src_mem_p" + "reorder_p";
  conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
+ auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_reorder_key));
  src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
- dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
- if (src_memory_p) {
+ if(src_memory_reorder_p){
+   user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(user_src_key));
+   user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
+ } else if(src_memory_p){
    src_memory_p->set_data_handle(to_void_cast<T>(input_data));
  }
+ dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
  std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
  conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd));
  std::shared_ptr<ConvMKLDNNHandler> handler;
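Note: the hunk above extends the existing caching scheme. Primitives and memories are stored in the device context under string keys (GetBlob/SetBlob); on a cache hit only the data handles are swapped, and the new src_reorder_key / user_src_key entries let the cached input reorder be reused as well. A minimal sketch of that lookup-or-create pattern, using a hypothetical BlobCache/ConvPrimitive stand-in rather than the real MKLDNNDeviceContext types:

    #include <memory>
    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for the device-context blob cache used above.
    struct BlobCache {
      std::unordered_map<std::string, std::shared_ptr<void>> blobs;
      void SetBlob(const std::string& key, std::shared_ptr<void> v) { blobs[key] = v; }
      std::shared_ptr<void> GetBlob(const std::string& key) {
        auto it = blobs.find(key);
        return it == blobs.end() ? nullptr : it->second;
      }
    };

    struct ConvPrimitive { /* placeholder for mkldnn::convolution_forward */ };

    // Look up a cached primitive by key; create and cache it on a miss.
    std::shared_ptr<ConvPrimitive> GetOrCreateConv(BlobCache& cache, const std::string& key) {
      auto prim_key = key + "@conv_p";
      auto conv_p = std::static_pointer_cast<ConvPrimitive>(cache.GetBlob(prim_key));
      if (!conv_p) {
        conv_p = std::make_shared<ConvPrimitive>();  // build the real primitive here
        cache.SetBlob(prim_key, conv_p);
      }
      return conv_p;
    }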
@@ -414,7 +426,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      residual_data_tz, residual_data_type, residual_param->format());
  auto user_residual_memory_p = handler->AcquireResidualDataMemory(
      user_residual_md, to_void_cast<T>(residual_param_data));
  dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
      user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
  } else {
@@ -427,6 +438,30 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    output->mutable_data<T>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
    dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
  }
+ } else if(is_INT8 && dst_memory_p){
+   if(fuse_residual_conn) {
+     auto residual_param = ctx.Input<Tensor>("ResidualData");
+     auto residual_dt = paddle::framework::ToMKLDNNDataType(residual_param->type());
+     output->ShareDataWith(*residual_param);
+     if(residual_dt == mkldnn::memory::data_type::u8){
+       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+       dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
+     } else{
+       int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+       dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
+     }
+   } else if(!force_fp32_output){
+     if(fuse_relu){
+       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+       dst_memory_p->set_data_handle(to_void_cast<uint8_t>(output_data));
+     } else{
+       int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+       dst_memory_p->set_data_handle(to_void_cast<int8_t>(output_data));
+     }
+   } else {
+     float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+     dst_memory_p->set_data_handle(to_void_cast<float>(output_data));
+   }
  }
  if(!is_INT8){
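Note: the new else-if branch above chooses the output buffer type for the cached INT8 path. With residual fusion the output shares the residual tensor's buffer and data type (u8 or s8), otherwise it is u8 when ReLU is fused, s8 when it is not, and plain float when force_fp32_output is set (residual fusion and force_fp32_output are mutually exclusive, enforced earlier in this commit). A condensed sketch of that decision, with a hypothetical helper name:

    #include <cstdint>

    enum class OutDType { kU8, kS8, kF32 };

    // Hypothetical condensation of the branch above: which element type the
    // INT8 conv output buffer is allocated with.
    OutDType PickInt8OutputType(bool fuse_residual_conn, bool residual_is_u8,
                                bool fuse_relu, bool force_fp32_output) {
      if (fuse_residual_conn)            // output reuses the residual buffer
        return residual_is_u8 ? OutDType::kU8 : OutDType::kS8;
      if (force_fp32_output)             // INT8 compute, fp32 result
        return OutDType::kF32;
      return fuse_relu ? OutDType::kU8   // ReLU guarantees non-negative values
                       : OutDType::kS8;
    }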
@@ -462,11 +497,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
  conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
                                 strides, paddings, mkldnn_engine,
-                                fuse_relu, fuse_residual_conn);
+                                fuse_relu, fuse_residual_conn, is_test);
  } else {
    conv_pd =
        ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                            mkldnn_engine, fuse_relu, fuse_residual_conn);
+                            mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
  }
  // Save conv_pd/src_memory/weights_memory for backward pass
  dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -474,7 +509,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
  // create mkldnn memory from input tensors (data/weights)
- auto user_src_memory_p =
+ user_src_memory_p =
      handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
  auto user_weights_memory_p = handler->AcquireWeightsMemory(
      user_weights_md, to_void_cast<float>(filter_data));
@@ -508,7 +543,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      residual_data_tz, residual_data_type, residual_param->format());
  auto user_residual_memory_p = handler->AcquireResidualDataMemory(
      user_residual_md, to_void_cast<T>(residual_param_data));
  dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
      user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
  } else {
@@ -546,10 +580,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  } else {
+   if(src_memory_reorder_p){
+     pipeline.push_back(*src_memory_reorder_p);
+   }
    pipeline.push_back(*conv_p);
    stream(stream::kind::eager).submit(pipeline).wait();
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  }
@@ -572,7 +609,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  auto sum_scale_key = key + "@sum_scale";
  auto scale_in_eltwise_key = key + "@scale_in_eltwise";
  std::vector<float> scale_in_data;
- std::vector<float> scale_out_data;
+ std::vector<float> scale_out_data = {1.0f};
  std::vector<float> scale_weights_data;
  std::vector<float> scale_in_eltwise_data;
  std::vector<float> output_shift_scale;
@@ -591,7 +628,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  for(int i=0; i<count; i++){
    scale_weights_data[i] =*(scale_weights->data<float>() + i);
  }
- scale_out_data = {*(scale_out->data<float>())};
+ if(!force_fp32_output)
+   scale_out_data = {*(scale_out->data<float>())};
  output_shift_scale.resize(count);
  #pragma omp parallel for if (count > 1)
  for(int i=0; i<count; i++){
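Note: scale_out_data now defaults to 1.0f and the quantized output scale is only read when the result is not forced to fp32, so with force_fp32_output the convolution output is de-scaled back to floats instead of being re-quantized. The loop body that combines the scales is elided in this hunk; presumably it computes per-channel shift factors along the lines of the sketch below (an assumption based on the variables visible above, not shown in the diff):

    #include <vector>

    // Assumed reconstruction of the elided per-channel scale computation:
    // with force_fp32_output, scale_out stays at 1.0f, so the output is only
    // de-scaled from the int8 input/weight domain, not re-quantized.
    std::vector<float> ComputeOutputShiftScale(float scale_in,
                                               const std::vector<float>& scale_weights,
                                               float scale_out /* 1.0f for fp32 out */) {
      std::vector<float> output_shift_scale(scale_weights.size());
      for (size_t i = 0; i < scale_weights.size(); ++i) {
        output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]);
      }
      return output_shift_scale;
    }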
@@ -650,6 +688,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char)))
      : paddle::framework::ToMKLDNNDataType(std::type_index(typeid(signed char)));
+ if(force_fp32_output){
+   dst_dt = paddle::framework::ToMKLDNNDataType(std::type_index(typeid(float)));
+ }
  if(fuse_residual_conn){
    auto residual = ctx.Input<Tensor>("ResidualData");
    auto residual_dt = paddle::framework::ToMKLDNNDataType(residual->type());
@@ -678,7 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
  // create mkldnn memory from input tensors (data/weights)
- auto user_src_memory_p =
+ user_src_memory_p =
      handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
  auto user_weights_memory_p = handler->AcquireWeightsMemory(
      user_weights_md, to_void_cast<float>(filter_data));
@@ -710,7 +752,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    dst_memory_p =
        handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
  }
- } else {
+ } else if(!force_fp32_output){
    if(fuse_relu){
      uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
      dst_memory_p =
@@ -720,6 +762,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      dst_memory_p =
          handler->AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
    }
+ } else {
+   float* output_data = output->mutable_data<float>(ctx.GetPlace(), ::paddle::memory::Allocator::kDefault, handler->GetDstMemorySize());
+   dst_memory_p =
+       handler->AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));
  }
  // create convolution op primitive
@@ -765,6 +811,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  } else {
+   if(src_memory_reorder_p){
+     pipeline.push_back(*src_memory_reorder_p);
+   }
    pipeline.push_back(*conv_p);
    stream(stream::kind::eager).submit(pipeline).wait();
@@ -1141,7 +1190,8 @@ namespace ops = paddle::operators;
  REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
                     ops::ConvMKLDNNOpKernel<float>,
-                    ops::ConvMKLDNNOpKernel<uint8_t>);
+                    ops::ConvMKLDNNOpKernel<uint8_t>,
+                    ops::ConvMKLDNNOpKernel<int8_t>);
  REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                     ops::ConvMKLDNNGradOpKernel<float>);

(The hunks below belong to the dequantize MKLDNN kernel.)

@@ -30,7 +30,6 @@ using Tensor = framework::Tensor;
  using framework::DataLayout;
  using mkldnn::stream;
  using platform::GetMKLDNNFormat;
- //using MKLDNNDataType = mkldnn::memory::data_type;
  template <typename T>
  class DeQuantOpKernel : public framework::OpKernel<T> {
@@ -46,7 +45,6 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
  const T* input_data = input->data<T>();
  float* output_data = output->mutable_data<float>(ctx.GetPlace());
- //T scale_data = *(scale->data<T>());
  std::vector<float> scale_data = {*(scale->data<float>())};
  std::vector<float> reorder_scale = {1.0f / scale_data[0]};
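Note: the dequantize kernel realizes fp32 = q / scale as a single MKLDNN reorder whose output scale is 1/scale (reorder_scale above). Element-wise it is equivalent to the plain scalar loop below (an illustration of the math only, not the MKLDNN reorder path):

    #include <cstdint>
    #include <vector>

    // Scalar equivalent of the reorder-based dequantization: each int8 value
    // is mapped back to float by multiplying with 1/scale.
    std::vector<float> DequantizeRef(const std::vector<int8_t>& q, float scale) {
      const float reorder_scale = 1.0f / scale;
      std::vector<float> out(q.size());
      for (size_t i = 0; i < q.size(); ++i) {
        out[i] = static_cast<float>(q[i]) * reorder_scale;
      }
      return out;
    }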
@@ -77,7 +75,6 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
  pipeline.push_back(*reorder_p);
  stream(stream::kind::eager).submit(pipeline).wait();
- //output->set_layout(DataLayout::kMKLDNN);
  output->set_format(GetMKLDNNFormat(dst_memory));
  }
@@ -114,5 +111,5 @@ namespace ops = paddle::operators;
  REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
- REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>);
+ REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);

(The hunks below belong to the quantize MKLDNN kernel.)

@@ -46,7 +46,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
  std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
  const T* input_data = input->data<T>();
- uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
  std::vector<T> scale_data = {*(scale->data<T>())};
  mkldnn::primitive_attr attri;
@@ -59,20 +59,32 @@ class QuantOpKernel : public framework::OpKernel<T> {
  auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
  std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
- auto dst_md = platform::MKLDNNMemDesc(
-     {dst_tz}, memory::data_type::u8, memory::format::nhwc);
- auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
- auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data));
+ bool is_negative = ctx.Attr<bool>("is_negative_input");
+ mkldnn::memory::primitive_desc dst_pd;
+ std::shared_ptr<mkldnn::memory> dst_memory;
+ if (is_negative) {
+   int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+   auto dst_md = platform::MKLDNNMemDesc(
+       {dst_tz}, memory::data_type::s8, memory::format::nhwc);
+   dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
+   dst_memory.reset(new mkldnn::memory(dst_pd, to_void_cast<int8_t>(output_data)));
+ } else {
+   uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+   auto dst_md = platform::MKLDNNMemDesc(
+       {dst_tz}, memory::data_type::u8, memory::format::nhwc);
+   dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
+   dst_memory.reset(new mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data)));
+ }
  auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
      new reorder::primitive_desc(src_pd, dst_pd, attri));
- auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
+ auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, *dst_memory));
  pipeline.push_back(*reorder_p);
  stream(stream::kind::eager).submit(pipeline).wait();
  output->set_layout(DataLayout::kMKLDNN);
- output->set_format(GetMKLDNNFormat(dst_memory));
+ output->set_format(GetMKLDNNFormat(*dst_memory));
  }
};
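Note: the reworked QuantOpKernel picks the destination type from the new is_negative_input attribute: inputs that can be negative quantize to s8, non-negative ones to u8, and the reorder applies the scale read into scale_data (the attri setup is in the elided lines). A scalar reference of that behaviour, assuming the usual s8/u8 saturation ranges and round-to-nearest (the exact rounding of the MKLDNN reorder is not shown in this diff):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Scalar sketch of the quantizing reorder: q = saturate(round(x * scale)),
    // into s8 when the input may be negative, otherwise into u8.
    std::vector<int32_t> QuantizeRef(const std::vector<float>& x, float scale,
                                     bool is_negative_input) {
      const float lo = is_negative_input ? -128.0f : 0.0f;
      const float hi = is_negative_input ? 127.0f : 255.0f;
      std::vector<int32_t> q(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        float v = std::round(x[i] * scale);
        q[i] = static_cast<int32_t>(std::min(hi, std::max(lo, v)));
      }
      return q;
    }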