提交 fcbe4898 编写于 作者: X xiaolil1

modify for eltwise with some useless log

上级 4a1346e5
...@@ -821,6 +821,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -821,6 +821,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
Type(), last_input_name, data_type, ipt_name, tmp); Type(), last_input_name, data_type, ipt_name, tmp);
data_type = tmp; data_type = tmp;
std::cout<<"data_type = "<<data_type;
last_input_name = ipt_name; last_input_name = ipt_name;
} }
} }
......
...@@ -54,6 +54,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -54,6 +54,7 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
} }
size_t GetDstMemorySize() const { size_t GetDstMemorySize() const {
std::cout<<"dst size = "<<conv_pd_->dst_primitive_desc().get_size()<<std::endl;
return conv_pd_->dst_primitive_desc().get_size(); return conv_pd_->dst_primitive_desc().get_size();
} }
...@@ -121,9 +122,9 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -121,9 +122,9 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline, bool is_INT8=false) { // NOLINT
auto src_pd = conv_pd_->src_primitive_desc(); auto src_pd = conv_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc(); auto user_pd = is_INT8? src_pd : user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p", return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
pipeline); pipeline);
} }
...@@ -274,7 +275,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -274,7 +275,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
std::cout<<"this is conv kernel op....................."<<std::endl;
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
...@@ -324,7 +325,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -324,7 +325,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet"); "dilation in convolution is not implemented yet");
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const float* filter_data = filter->data<float>();
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz = std::vector<int> weights_tz =
...@@ -344,17 +345,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -344,17 +345,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
std::vector<T> output_shift_scale; std::vector<float> output_shift_scale;
T sum_scale = 1.0f; float sum_scale = 1.0f;
if(is_INT8){ if(is_INT8){
std::cout<<"this is conv int8 op .............."<<std::endl;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
T scale_in_data = *(scale_in->data<T>()); float scale_in_data = *(scale_in->data<float>());
T scale_in_eltwise_data = *(scale_in_eltwise->data<T>()); std::vector<float> scale_weights_data(count);
std::vector<T> scale_weights_data(count);
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_weights_data[i] =*(scale_weights->data<T>() + i); scale_weights_data[i] =*(scale_weights->data<float>() + i);
} }
T scale_out_data = *(scale_out->data<T>()); float scale_out_data = *(scale_out->data<float>());
output_shift_scale.resize(count); output_shift_scale.resize(count);
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
...@@ -363,8 +364,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -363,8 +364,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
else else
output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]); output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]);
} }
if(fuse_residual_conn){
sum_scale = scale_out_data / scale_in_eltwise_data; float scale_in_eltwise_data = *(scale_in_eltwise->data<float>());
sum_scale = scale_out_data / scale_in_eltwise_data;
}
} }
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
...@@ -378,7 +381,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -378,7 +381,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format()); {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = platform::MKLDNNMemDesc( auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), {weights_tz}, platform::MKLDNNGetDataType<float>(),
(g == 1) ? filter->format() : mkldnn::memory::format::goihw); (g == 1) ? filter->format() : mkldnn::memory::format::goihw);
/* create memory descriptor for convolution without specified format /* create memory descriptor for convolution without specified format
...@@ -399,12 +402,28 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -399,12 +402,28 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
if(is_INT8){
src_md = platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format);
weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8,
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
dst_md = platform::MKLDNNMemDesc(
dst_tz,
fuse_relu?memory::data_type::u8:memory::data_type::s8,
chosen_memory_format);
}
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) { if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x);
if(is_INT8){
bias_md = platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x);
}
if(is_INT8){ if(is_INT8){
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
...@@ -436,62 +455,85 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -436,62 +455,85 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_src_memory_p = auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<float>(filter_data));
T* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline, is_INT8);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test); user_weights_memory_p, pipeline, is_test);
if(is_INT8){ if(is_INT8){
int mask_reorder = is_multi_channel? 0 : ((g!= 1) ? (1<<1)+(1<<0) : 1<<0); int mask_reorder = is_multi_channel? 0 : ((g!= 1) ? (1<<1)+(1<<0) : 1<<0);
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_weights_data(count); std::vector<float> scale_weights_data(count);
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_weights_data[i] = *(scale_weights->data<T>() + i); scale_weights_data[i] = *(scale_weights->data<T>() + i);
} }
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder); user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
} }
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); std::shared_ptr<mkldnn::memory> dst_memory_p;
if(is_INT8){
int8_t* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<int8_t>(ctx.GetPlace());
} else {
std::cout<<"conv log 1 ....................."<<std::endl;
output_data =
output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
std::cout<<"conv log 2 //////////////////////"<<std::endl;
}
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
} else{
T* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
std::cout<<"conv log 1 ....................."<<std::endl;
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
std::cout<<"conv log 2 //////////////////////"<<std::endl;
}
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
}
// create convolution op primitive // create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const float* bias_data = bias->data<float>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
auto user_bias_memory_p = auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler.AcquireBiasMemory(user_bias_md, to_void_cast<float>(bias_data));
auto bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
if(is_INT8){ if(is_INT8){
int mask_reorder = is_multi_channel? 0 : 1<<0; int mask_reorder = is_multi_channel? 0 : 1<<0;
int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1; int count = is_multi_channel? (g>1? weights_tz[1]*weights_tz[0] : weights_tz[0]) : 1;
std::vector<T> scale_bias_data(count); std::vector<float> scale_bias_data(count);
for(int i=0; i<count; i++){ for(int i=0; i<count; i++){
scale_bias_data[i] = (*scale_in->data<T>()) * (*(scale_weights->data<T>() + i)); scale_bias_data[i] = (*scale_in->data<float>()) * (*(scale_weights->data<float>() + i));
} }
auto bias_memory_p = auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline, is_INT8, scale_bias_data, mask_reorder);
...@@ -503,17 +545,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -503,17 +545,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p); dst_memory_p);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
} }
private: private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn, mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
const std::vector<T> output_shift_scale, T sum_scale) const { const std::vector<float> output_shift_scale, float sum_scale) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
...@@ -568,7 +612,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -568,7 +612,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_residual_conn,
const std::vector<T> output_shift_scale, const T sum_scale) const { const std::vector<float> output_shift_scale, const float sum_scale) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -617,7 +661,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -617,7 +661,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_residual_conn,
const std::vector<T> output_shift_scale, const T sum_scale) const { const std::vector<float> output_shift_scale, const float sum_scale) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -841,7 +885,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -841,7 +885,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>); ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>); ops::ConvMKLDNNGradOpKernel<float>);
...@@ -94,10 +94,10 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -94,10 +94,10 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
auto input_data_type = auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Input")->type()); framework::ToDataType(ctx.Input<Tensor>("Input")->type());
auto filter_data_type = //auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type()); // framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, //PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent"); // "input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN, PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
......
...@@ -40,7 +40,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -40,7 +40,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<Tensor>("Input");
auto* scale = ctx.Input<Tensor>("Scale"); auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
std::cout<<"this is dequant op ***********"<<std::endl;
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine(); const auto& engine = dev_ctx.GetEngine();
......
...@@ -37,7 +37,7 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<Tensor>("Input");
auto* scale = ctx.Input<Tensor>("Scale"); auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
std::cout<<"this is quantize op!!!!!!!!!!!!!!"<<std::endl;
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine(); const auto& engine = dev_ctx.GetEngine();
...@@ -68,7 +68,12 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -68,7 +68,12 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(dst_pd, src_pd, attri)); new reorder::primitive_desc(dst_pd, src_pd, attri));
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory)); auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p); pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
} }
}; };
......
...@@ -153,8 +153,11 @@ class MKLDNNHandler { ...@@ -153,8 +153,11 @@ class MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context"); "Fail to find mem primitive in device context");
//mem_p = nullptr;
if (mem_p == nullptr) { if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr); mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
std::cout<<"mem_p == null"<<std::endl;
//std::cout<<"mdp fmt = "<<mdp.desc().data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
} else { } else {
mem_p->set_data_handle(ptr); mem_p->set_data_handle(ptr);
...@@ -162,6 +165,7 @@ class MKLDNNHandler { ...@@ -162,6 +165,7 @@ class MKLDNNHandler {
// should be reused or none of them. So we check consistency // should be reused or none of them. So we check consistency
is_reusing_ = true; is_reusing_ = true;
} }
std::cout<<"mdp fmt = "<<mdp.desc().data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
return mem_p; return mem_p;
} }
...@@ -174,7 +178,9 @@ class MKLDNNHandler { ...@@ -174,7 +178,9 @@ class MKLDNNHandler {
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context"); "Fail to find mem primitive in device context");
//mem_p = nullptr;
if (mem_p == nullptr) { if (mem_p == nullptr) {
std::cout<<"mem_p == null"<<std::endl;
mem_p = std::make_shared<mkldnn::memory>( mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{md, engine_}, ptr); mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p); dev_ctx_.SetBlob(local_key, mem_p);
...@@ -184,6 +190,7 @@ class MKLDNNHandler { ...@@ -184,6 +190,7 @@ class MKLDNNHandler {
// should be reused or none of them. So we check consistency // should be reused or none of them. So we check consistency
is_reusing_ = true; is_reusing_ = true;
} }
std::cout<<"md fmt = "<<md.data.format<<" mem_p fmt = "<<mem_p->get_primitive_desc().desc().data.format<<std::endl;
return mem_p; return mem_p;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册