diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc old mode 100644 new mode 100755 index 90bb7337f2a355ec89553c5f9d46f5571a76394f..3a7aa273f27ba08446ba9909eb1930dea11aac24 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc @@ -84,7 +84,7 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const { for (const auto& name_shape : key.input_shapes_) { has_str << name_shape.first; - has_str << name_shape.second.to_str(); + has_str << std::hash()(name_shape.second); } has_str << key.graph_hash_val_; diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index 982fedfe23d8c47d66bb3415e13324eebb6ec674..e7da2d636c3e6b137789ff7995e78607ddbcb58a 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -255,6 +255,8 @@ void CinnLaunchContext::InitializeArguments() { framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(), cinn_tensor->type()); name2argument_.emplace(arg, cinn_buffer.get()); + auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg); + paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get()); hold_buffers_.emplace_back(std::move(cinn_buffer)); } VLOG(4) << "Total argument size:" << name2argument_.size(); @@ -491,17 +493,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore( cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar( const std::string& var_name) { - auto it = paddle2cinn_varmap_.find(var_name); + auto res = paddle2argument_.find(var_name); PADDLE_ENFORCE_NE( - it, - paddle2cinn_varmap_.end(), - platform::errors::InvalidArgument( - "Variable(%s) not found in compilation result", var_name)); - auto res = name2argument_.find(it->second); - PADDLE_ENFORCE_NE(res, - name2argument_.end(), - platform::errors::NotFound( - "Argument(%s) not be initialized", it->second)); + res, + paddle2argument_.end(), + platform::errors::NotFound("Variable(%s) not found in compilation result", + var_name)); return static_cast(res->second); } diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.h b/paddle/fluid/operators/cinn/cinn_launch_context.h index d6ce95de0859d0e3d2e63e0dbaef8ae7bbac5036..e66658750bb230e0a81cb31095bacbf509a6d635 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.h +++ b/paddle/fluid/operators/cinn/cinn_launch_context.h @@ -177,6 +177,9 @@ class CinnLaunchContext { // this map saves all execution arguments with their cinn names as key, // and it is passed to the Execute interface of a cinn runtime program. std::map name2argument_; + // this map saves all execution arguments with paddle variables as key, + // this map conbine name2argument_ and paddle2cinn_varmap_ + std::map paddle2argument_; }; } // namespace operators::details diff --git a/paddle/phi/core/ddim.cc b/paddle/phi/core/ddim.cc index 18778c9abf60f6051599fe4419b8029ac8b812c1..3256458e02be9e9659cf3bbba52503bc0277c459 100644 --- a/paddle/phi/core/ddim.cc +++ b/paddle/phi/core/ddim.cc @@ -203,3 +203,16 @@ DDim DDim::transpose(const std::vector& axis) const { } } // namespace phi + +namespace std { + +std::size_t hash::operator()(phi::DDim const& ddim) const { + int ndim = ddim.size(); + std::size_t seed = ndim; + for (int i = 0; i < ndim; ++i) { + seed ^= ddim.Get()[i] + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; +} + +} // namespace std diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h old mode 100644 new mode 100755 index 794d7051aee58e4fc63c40f2d7dd361ed6b834ba..8003df7fd6842d286b037679d4c71dd9257ac353 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -155,7 +155,7 @@ class DDim { std::string to_str() const; - DDim reshape(std::vector& shape) const; + DDim reshape(std::vector& shape) const; // NOLINT DDim transpose(const std::vector& axis) const; @@ -262,3 +262,10 @@ using DDim = phi::DDim; } // namespace framework } // namespace paddle + +namespace std { +template <> +struct hash { + std::size_t operator()(phi::DDim const& ddim) const; +}; +} // namespace std diff --git a/paddle/phi/tests/core/test_ddim.cc b/paddle/phi/tests/core/test_ddim.cc old mode 100644 new mode 100755 index 72c91b452296f968a6b89bb545b3ed1d8ad64b73..0251e3f8bb9116082b808883ec068e6ab4801e1b --- a/paddle/phi/tests/core/test_ddim.cc +++ b/paddle/phi/tests/core/test_ddim.cc @@ -124,5 +124,13 @@ TEST(DDim, Print) { EXPECT_EQ("", ss2.str()); } +TEST(DDim, Hash) { + // hash a DDim + std::size_t h; + phi::DDim ddim = phi::make_ddim({2, 3, 4}); + h = std::hash()(ddim); + EXPECT_EQ(h, 0xa16fb2b2967ul); +} + } // namespace tests } // namespace phi