Commit 2c32c2d6 authored by lidanqing, committed by Tao Luo

Refactor conv computeINT8 (#19574)

* fix conflicts
test=develop

* change mask_bias_reorder
test=develop

* add ComputeMask function to make code clear
test=develop

* change according to reviews
test=develop

* change according to reviews
test=develop
Parent 3f1d0234
@@ -980,15 +980,6 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
// TODO(jczaja): remove after conv int8 is adapted
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
conv_pd_ = conv_pd;
}
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
std::shared_ptr<typename backward_data_t::primitive_desc>
@@ -1309,47 +1300,6 @@ using ConvTransposeMKLDNNHandler =
mkldnn::deconvolution_backward_data,
mkldnn::deconvolution_backward_weights>;
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
return dst_memory_p;
}
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const framework::Tensor* residual_param,
const mkldnn::memory::desc& user_residual_md,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::vector<mkldnn::primitive>* pipeline) {
const T* residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
handler->AcquireResidualDataMemory(user_residual_md,
to_void_cast<T>(residual_param_data));
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), *pipeline);
return dst_memory_p;
}
template <typename T>
static void SetDstMemoryHandler(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
template <typename T>
static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output,
......
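The SetDstMemoryHandler helper removed above reuses an already-created destination memory object by repointing it at the output tensor's newly allocated buffer, rather than rebuilding the memory or the primitive on every run. Below is a minimal, self-contained C++ sketch of that reuse pattern; the Memory struct and every name in it are illustrative stand-ins, not the real mkldnn::memory API or code from this repository.

```cpp
// Sketch of the "repoint a cached destination memory at a fresh output
// buffer" pattern. Memory is a stand-in for mkldnn::memory; nothing here
// is taken from the actual library.
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

struct Memory {
  void* handle = nullptr;
  // Mimics mkldnn::memory::set_data_handle().
  void set_data_handle(void* p) { handle = p; }
  void* get_data_handle() const { return handle; }
};

int main() {
  // A destination memory cached from a previous iteration still refers
  // to the old output buffer.
  auto dst_memory = std::make_shared<Memory>();
  std::vector<float> old_output(16, 0.f);
  dst_memory->set_data_handle(old_output.data());

  // New iteration: the framework allocates a fresh output buffer
  // (the analogue of output->mutable_data<T>(place, size)).
  std::vector<float> new_output(16, 0.f);

  // Instead of recreating the memory object, just repoint it at the new
  // buffer -- the essence of SetDstMemoryHandler above.
  dst_memory->set_data_handle(new_output.data());

  assert(dst_memory->get_data_handle() == new_output.data());
  std::cout << "dst memory now points at the new output buffer\n";
  return 0;
}
```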