Commit 1f8afa6f authored by xiaolil1

enable initialization for INT8

Parent 9dead9a2
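
For context: this change applies the primitive-caching pattern used by the MKL-DNN kernels to the INT8 path. On the first iteration the convolution primitive and its src/dst memories are built and stored in the device context under string keys ("@conv_p", "@src_mem_p", "@dst_mem_p"); later iterations fetch them back and only re-bind the current input buffer via set_data_handle. Below is a minimal standalone sketch of that pattern, with BlobStore and Primitive as simplified stand-ins for MKLDNNDeviceContext and mkldnn::convolution_forward (not the real API):

#include <memory>
#include <string>
#include <unordered_map>

struct Primitive {};  // stand-in; assume construction is expensive

// Stand-in for the device context's blob cache (GetBlob/SetBlob).
class BlobStore {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

std::shared_ptr<Primitive> AcquireConvPrimitive(BlobStore& dev_ctx,
                                                const std::string& key) {
  const auto prim_key = key + "@conv_p";
  // Cache hit: reuse the primitive created on an earlier iteration.
  auto conv_p = std::static_pointer_cast<Primitive>(dev_ctx.GetBlob(prim_key));
  if (conv_p == nullptr) {
    // Cache miss (first iteration): build once, then store for reuse.
    conv_p = std::make_shared<Primitive>();
    dev_ctx.SetBlob(prim_key, conv_p);
  }
  return conv_p;
}

On a cache hit the kernel does not rebuild anything; it only points the cached src memory at the current batch's input buffer, which is what the set_data_handle call in the diff below does.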
@@ -496,6 +496,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 output->set_format(GetMKLDNNFormat(*dst_memory_p));
 } else {
+bool need_s8_to_u8 = false;
+if (fuse_residual_conn && fuse_relu) {
+need_s8_to_u8 = true;
+}
+std::shared_ptr<mkldnn::convolution_forward> conv_p;
+std::shared_ptr<mkldnn::memory> src_memory_p;
+std::shared_ptr<mkldnn::memory> dst_memory_p;
+std::vector<primitive> pipeline;
+auto prim_key = key + "@conv_p";
+auto dst_key = key + "@dst_mem_p";
+auto src_key = key + "@src_mem_p";
+conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
+src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
+dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
+if (src_memory_p) {
+src_memory_p->set_data_handle(to_void_cast<T>(input_data));
+}
+if (conv_p == nullptr) {
 auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
 auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise") ? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
 auto* scale_weights = ctx.HasInput("Scale_weights") ? ctx.Input<Tensor>("Scale_weights") : nullptr;
@@ -627,7 +648,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 user_weights_md, to_void_cast<float>(filter_data));
 // create reorder primitive if the input format is not the preferred one
-auto src_memory_p =
+src_memory_p =
 handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
 std::shared_ptr<mkldnn::memory> weights_memory_p;
@@ -635,8 +656,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
 user_weights_memory_p, pipeline, is_test, is_INT8, scale_weights_data, mask_reorder);
-std::shared_ptr<mkldnn::memory> dst_memory_p;
-bool need_s8_to_u8 = false;
 if (fuse_residual_conn) {
 auto residual_param = ctx.Input<Tensor>("ResidualData");
 PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
@@ -654,8 +673,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
 dst_memory_p =
 handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
-if (fuse_relu)
-need_s8_to_u8 = true;
 }
 } else {
 if (fuse_relu) {
@@ -670,7 +687,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 }
 // create convolution op primitive
-std::shared_ptr<mkldnn::convolution_forward> conv_p;
 std::vector<float> scale_bias_data;
 auto scale_bias_key = key + "@scale_bias";
 if (bias) {
@@ -712,6 +728,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 output->set_layout(DataLayout::kMKLDNN);
 output->set_format(GetMKLDNNFormat(*dst_memory_p));
+} else {
+pipeline.push_back(*conv_p);
+stream(stream::kind::eager).submit(pipeline).wait();
+if (need_s8_to_u8) {
+output->mutable_data<uint8_t>(ctx.GetPlace());
+}
+output->set_layout(DataLayout::kMKLDNN);
+output->set_format(GetMKLDNNFormat(*dst_memory_p));
+}
 }
 }
......
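
A note on the need_s8_to_u8 flag in the cached branch: when fuse_residual_conn and fuse_relu are both set, the convolution sums into an s8 buffer, but the fused ReLU leaves only non-negative values, so the same bytes are also valid u8 data. The call to output->mutable_data<uint8_t>(ctx.GetPlace()) can therefore re-register the output with the unsigned type (assuming Tensor::mutable_data keeps the existing allocation when the byte size is unchanged) instead of running a conversion pass. A small standalone illustration of why the in-place reinterpretation is value-preserving (not PaddlePaddle code):

#include <cassert>
#include <cstdint>

int main() {
  // A post-ReLU INT8 activation always lies in [0, 127].
  int8_t s8 = 93;
  // The same byte read as unsigned keeps its value, so no copy
  // or rescaling is needed when switching the tensor's type.
  uint8_t u8 = static_cast<uint8_t>(s8);
  assert(u8 == 93);
  return 0;
}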