提交 7afea1b6 编写于 作者: H Haihao Shen

Fix the potential issue on conversion from s8 to u8

上级 1e8baec5
...@@ -479,6 +479,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -479,6 +479,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
std::shared_ptr<mkldnn::memory> dst_memory_p; std::shared_ptr<mkldnn::memory> dst_memory_p;
bool need_s8_to_u8 = false;
if(is_INT8){ if(is_INT8){
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
...@@ -496,6 +497,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -496,6 +497,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace()); int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p = dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
if(fuse_relu)
need_s8_to_u8 = true;
} }
} else { } else {
if(fuse_relu){ if(fuse_relu){
...@@ -565,7 +568,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -565,7 +568,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
if(is_INT8 && fuse_residual_conn){ if(need_s8_to_u8){
output->mutable_data<uint8_t>(ctx.GetPlace()); output->mutable_data<uint8_t>(ctx.GetPlace());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册