未验证 提交 c62902ee 编写于 作者: Z zhanglirong1999 提交者: GitHub

[ONEDNN] fix accuracy issue of fc when the input shapes are dynamic

上级 3bcc91e4
......@@ -284,7 +284,13 @@ class FCMKLDNNHandler
std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
const phi::DenseTensor* weights, const std::vector<float>& scale_data) {
const std::string weights_key = this->memory_key_ + "@weights";
const std::string weights_base_key = this->memory_key_ + "@weights";
std::string weights_key;
weights_key.reserve(128);
weights_key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx_,
phi::funcs::CreateKey(
dev_ctx_, weights_base_key, this->fwd_pd_->weights_desc()));
auto memory_p = std::static_pointer_cast<dnnl::memory>(
this->dev_ctx_.GetBlob(weights_key));
......@@ -410,7 +416,8 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
phi::funcs::CreateKey(dev_ctx,
ctx.InputName("Input"),
ctx.InputName("W"),
phi::vectorize(x->dims())));
phi::vectorize(x->dims()),
phi::vectorize(weights->dims())));
auto inner_product_cache =
std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
......
......@@ -154,6 +154,12 @@ inline void AppendKey(std::string* key, const T& num) {
key->append(std::to_string(num));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::format_kind& format) {
key->append(std::to_string(static_cast<int>(format)));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::format_tag& format) {
......@@ -171,6 +177,25 @@ inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
key->append(std::to_string(static_cast<int>(algorithm)));
}
template <>
inline void AppendKey(std::string* key, const dnnl::memory::dims& dims) {
for (size_t i = 0; i < dims.size(); i++) {
AppendKey(key, static_cast<int64_t>(dims[i]));
}
}
template <>
inline void AppendKey(std::string* key, const dnnl::memory::desc& md) {
AppendKey(key, md.get_dims());
AppendKey(key, md.get_data_type());
AppendKey(key, md.get_format_kind());
AppendKey(key, md.get_inner_blks());
AppendKey(key, md.get_inner_idxs());
AppendKey(key, md.get_inner_nblks());
AppendKey(key, md.get_padded_dims());
AppendKey(key, md.get_strides());
}
template <>
inline void AppendKey(std::string* key,
const dnnl::normalization_flags& flags) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册