提交 7afea1b6 编写于 作者: H Haihao Shen

Fix the potential issue on conversion from s8 to u8

上级 1e8baec5
......@@ -479,6 +479,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
std::shared_ptr<mkldnn::memory> dst_memory_p;
bool need_s8_to_u8 = false;
if(is_INT8){
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
......@@ -496,6 +497,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
if(fuse_relu)
need_s8_to_u8 = true;
}
} else {
if(fuse_relu){
......@@ -565,7 +568,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
if(is_INT8 && fuse_residual_conn){
if(need_s8_to_u8){
output->mutable_data<uint8_t>(ctx.GetPlace());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册