From 1168a1785d77b1549531a876e34d42982aa84e88 Mon Sep 17 00:00:00 2001 From: GaoYuYang Date: Thu, 5 Jan 2023 10:38:08 +0800 Subject: [PATCH] Add to_hash func and paddle2arg map for cinn (#49402) --- .../framework/paddle2cinn/cinn_cache_key.cc | 2 +- .../fluid/operators/cinn/cinn_launch_context.cc | 17 +++++++---------- .../fluid/operators/cinn/cinn_launch_context.h | 3 +++ paddle/phi/core/ddim.cc | 13 +++++++++++++ paddle/phi/core/ddim.h | 9 ++++++++- paddle/phi/tests/core/test_ddim.cc | 8 ++++++++ 6 files changed, 40 insertions(+), 12 deletions(-) mode change 100644 => 100755 paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc mode change 100644 => 100755 paddle/phi/core/ddim.h mode change 100644 => 100755 paddle/phi/tests/core/test_ddim.cc diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc old mode 100644 new mode 100755 index 90bb7337f2..3a7aa273f2 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc @@ -84,7 +84,7 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const { for (const auto& name_shape : key.input_shapes_) { has_str << name_shape.first; - has_str << name_shape.second.to_str(); + has_str << std::hash()(name_shape.second); } has_str << key.graph_hash_val_; diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index 982fedfe23..e7da2d636c 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -255,6 +255,8 @@ void CinnLaunchContext::InitializeArguments() { framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(), cinn_tensor->type()); name2argument_.emplace(arg, cinn_buffer.get()); + auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg); + paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get()); hold_buffers_.emplace_back(std::move(cinn_buffer)); } VLOG(4) << "Total argument size:" << name2argument_.size(); @@ -491,17 +493,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore( cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar( const std::string& var_name) { - auto it = paddle2cinn_varmap_.find(var_name); + auto res = paddle2argument_.find(var_name); PADDLE_ENFORCE_NE( - it, - paddle2cinn_varmap_.end(), - platform::errors::InvalidArgument( - "Variable(%s) not found in compilation result", var_name)); - auto res = name2argument_.find(it->second); - PADDLE_ENFORCE_NE(res, - name2argument_.end(), - platform::errors::NotFound( - "Argument(%s) not be initialized", it->second)); + res, + paddle2argument_.end(), + platform::errors::NotFound("Variable(%s) not found in compilation result", + var_name)); return static_cast(res->second); } diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.h b/paddle/fluid/operators/cinn/cinn_launch_context.h index d6ce95de08..e66658750b 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.h +++ b/paddle/fluid/operators/cinn/cinn_launch_context.h @@ -177,6 +177,9 @@ class CinnLaunchContext { // this map saves all execution arguments with their cinn names as key, // and it is passed to the Execute interface of a cinn runtime program. std::map name2argument_; + // this map saves all execution arguments with paddle variables as key, + // this map conbine name2argument_ and paddle2cinn_varmap_ + std::map paddle2argument_; }; } // namespace operators::details diff --git a/paddle/phi/core/ddim.cc b/paddle/phi/core/ddim.cc index 18778c9abf..3256458e02 100644 --- a/paddle/phi/core/ddim.cc +++ b/paddle/phi/core/ddim.cc @@ -203,3 +203,16 @@ DDim DDim::transpose(const std::vector& axis) const { } } // namespace phi + +namespace std { + +std::size_t hash::operator()(phi::DDim const& ddim) const { + int ndim = ddim.size(); + std::size_t seed = ndim; + for (int i = 0; i < ndim; ++i) { + seed ^= ddim.Get()[i] + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; +} + +} // namespace std diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h old mode 100644 new mode 100755 index 794d7051ae..8003df7fd6 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -155,7 +155,7 @@ class DDim { std::string to_str() const; - DDim reshape(std::vector& shape) const; + DDim reshape(std::vector& shape) const; // NOLINT DDim transpose(const std::vector& axis) const; @@ -262,3 +262,10 @@ using DDim = phi::DDim; } // namespace framework } // namespace paddle + +namespace std { +template <> +struct hash { + std::size_t operator()(phi::DDim const& ddim) const; +}; +} // namespace std diff --git a/paddle/phi/tests/core/test_ddim.cc b/paddle/phi/tests/core/test_ddim.cc old mode 100644 new mode 100755 index 72c91b4522..0251e3f8bb --- a/paddle/phi/tests/core/test_ddim.cc +++ b/paddle/phi/tests/core/test_ddim.cc @@ -124,5 +124,13 @@ TEST(DDim, Print) { EXPECT_EQ("", ss2.str()); } +TEST(DDim, Hash) { + // hash a DDim + std::size_t h; + phi::DDim ddim = phi::make_ddim({2, 3, 4}); + h = std::hash()(ddim); + EXPECT_EQ(h, 0xa16fb2b2967ul); +} + } // namespace tests } // namespace phi -- GitLab