未验证 提交 1168a178 编写于 作者: G GaoYuYang 提交者: GitHub

Add to_hash func and paddle2arg map for cinn (#49402)

上级 1228bad0
......@@ -84,7 +84,7 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
for (const auto& name_shape : key.input_shapes_) {
has_str << name_shape.first;
has_str << name_shape.second.to_str();
has_str << std::hash<phi::DDim>()(name_shape.second);
}
has_str << key.graph_hash_val_;
......
......@@ -255,6 +255,8 @@ void CinnLaunchContext::InitializeArguments() {
framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
cinn_tensor->type());
name2argument_.emplace(arg, cinn_buffer.get());
auto pdvar2cinnbuf_ = cinn2paddle_varmap_.at(arg);
paddle2argument_.emplace(pdvar2cinnbuf_, cinn_buffer.get());
hold_buffers_.emplace_back(std::move(cinn_buffer));
}
VLOG(4) << "Total argument size:" << name2argument_.size();
......@@ -491,17 +493,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
const std::string& var_name) {
auto it = paddle2cinn_varmap_.find(var_name);
auto res = paddle2argument_.find(var_name);
PADDLE_ENFORCE_NE(
it,
paddle2cinn_varmap_.end(),
platform::errors::InvalidArgument(
"Variable(%s) not found in compilation result", var_name));
auto res = name2argument_.find(it->second);
PADDLE_ENFORCE_NE(res,
name2argument_.end(),
platform::errors::NotFound(
"Argument(%s) not be initialized", it->second));
res,
paddle2argument_.end(),
platform::errors::NotFound("Variable(%s) not found in compilation result",
var_name));
return static_cast<cinn_buffer_t*>(res->second);
}
......
......@@ -177,6 +177,9 @@ class CinnLaunchContext {
// this map saves all execution arguments with their cinn names as key,
// and it is passed to the Execute interface of a cinn runtime program.
std::map<std::string, cinn_pod_value_t> name2argument_;
// this map saves all execution arguments with paddle variables as key,
// this map conbine name2argument_ and paddle2cinn_varmap_
std::map<std::string, cinn_pod_value_t> paddle2argument_;
};
} // namespace operators::details
......
......@@ -203,3 +203,16 @@ DDim DDim::transpose(const std::vector<int>& axis) const {
}
} // namespace phi
namespace std {
std::size_t hash<phi::DDim>::operator()(phi::DDim const& ddim) const {
int ndim = ddim.size();
std::size_t seed = ndim;
for (int i = 0; i < ndim; ++i) {
seed ^= ddim.Get()[i] + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
} // namespace std
......@@ -155,7 +155,7 @@ class DDim {
std::string to_str() const;
DDim reshape(std::vector<int>& shape) const;
DDim reshape(std::vector<int>& shape) const; // NOLINT
DDim transpose(const std::vector<int>& axis) const;
......@@ -262,3 +262,10 @@ using DDim = phi::DDim;
} // namespace framework
} // namespace paddle
namespace std {
template <>
struct hash<phi::DDim> {
std::size_t operator()(phi::DDim const& ddim) const;
};
} // namespace std
......@@ -124,5 +124,13 @@ TEST(DDim, Print) {
EXPECT_EQ("", ss2.str());
}
TEST(DDim, Hash) {
// hash a DDim
std::size_t h;
phi::DDim ddim = phi::make_ddim({2, 3, 4});
h = std::hash<phi::DDim>()(ddim);
EXPECT_EQ(h, 0xa16fb2b2967ul);
}
} // namespace tests
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册