未验证 提交 485b387d 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] candidate fix to #34554 (#35884)

* - candidate fix

* - More fixes to #34554

* - another incosnstent fix to key

* - Remvoed unneeded line

* - matching the cache behaviour to other ops
上级 4f42e5d7
......@@ -706,7 +706,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::CreateKey(dev_ctx, src_tz, src_dt,
ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
......@@ -721,6 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// are merged/unified, this will disappear
auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_conv_pd = key_tid + "@conv_pd";
auto prim_key = key_tid + "@conv_p";
auto dst_key = key_tid + "@dst_mem_p";
auto src_key = key_tid + "@src_mem_p";
......@@ -731,12 +731,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_reorder_key = key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (conv_p == nullptr || !is_test) {
if (conv_pd == nullptr || !is_test) {
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
......@@ -946,7 +947,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
// create convolution op primitive
auto scale_bias_key = key + "@scale_bias";
conv_p = handler->AcquireConvolution();
if (bias) {
const K* bias_data = bias->data<K>();
......@@ -1000,13 +1000,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.GetBlob(weights_key));
dst_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
if (conv_pd) {
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));
}
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
......
......@@ -603,7 +603,6 @@ class MKLDNNHandler {
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
......@@ -789,7 +788,6 @@ class MKLDNNHandler {
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_common_;
std::string key_;
};
......@@ -1371,18 +1369,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
// Conv PD has to be passed to Grad op that
// may be exxecuted by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_conv_pd = key_common_ + "@conv_pd";
const std::string key_conv_pd = key_ + "@conv_pd";
conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd_ == nullptr) {
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims dilations_dims = dilations;
......@@ -1390,24 +1381,23 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, *bias, dst, stride_dims, dilations_dims,
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, *bias, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1])
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, dst, stride_dims, dilations_dims,
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, output_shift_scale, sum_scale);
conv_pd_.reset(new typename forward_t::primitive_desc(
conv_desc, conv_attr, engine));
conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
}
}
return conv_pd_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册