提交 c167474c 编写于 作者: X xiaolil1

modify requantize for s8 u8 reorder

上级 edc53a0d
......@@ -35,6 +35,7 @@ template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
//std::cout<<"this is requant op!!!!!"<<std::endl;
auto* input = ctx.Input<Tensor>("Input");
//auto* scale = ctx.Input<Tensor>("Scale");
auto* output = ctx.Output<Tensor>("Output");
......@@ -74,12 +75,20 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri));
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait();
int is_sum = ctx.Attr<int>("is_sum");
if(is_sum){
//std::cout<<"is_sum == true"<<std::endl;
memcpy(output_data, input_data, sizeof(uint8_t) * input->numel());
} else{
auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait();
}
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
//std::cout<<"requant op end!!!!!"<<std::endl;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册