未验证 提交 bce67644 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] add paddle2cinn visible variables api (#54732)

上级 24523c16
......@@ -21,6 +21,7 @@
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "cinn/common/target.h"
#include "paddle/fluid/framework/ir/graph.h"
......
......@@ -188,6 +188,14 @@ void CinnLaunchContext::BuildVarNameMap(
cinn2paddle_varmap_.size()));
}
std::unordered_set<std::string> CinnLaunchContext::GetVisibleVarNames() const {
std::unordered_set<std::string> remain_var_names;
for (const auto& pair : paddle2cinn_varmap_) {
remain_var_names.insert(this->RedirectVarName(pair.first));
}
return remain_var_names;
}
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
const platform::Place& place) {
if (std::addressof(scope) == cached_scope_ &&
......@@ -547,7 +555,8 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
return interpreter_core_.get();
}
std::string CinnLaunchContext::RedirectVarName(const std::string& var_name) {
std::string CinnLaunchContext::RedirectVarName(
const std::string& var_name) const {
auto pos = var_name.find(InplaceOutSuffix);
if (pos == std::string::npos) {
return var_name;
......
......@@ -97,13 +97,15 @@ class CinnLaunchContext {
}
// Redirect the name of a Paddle variable to the orignal if it was inplaced
std::string RedirectVarName(const std::string& var_name);
std::string RedirectVarName(const std::string& var_name) const;
// Return internal variable names list
const std::unordered_set<std::string>& GetInternalVarNames() const {
return internal_var_names_;
}
std::unordered_set<std::string> GetVisibleVarNames() const;
// Finalize all execution arguments and return the name->argument map
const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const {
return name2argument_;
......
......@@ -133,6 +133,11 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
end_t - start_t);
VLOG(1) << "Ends to compile at thread " << std::this_thread::get_id()
<< " , time cost : " << time_sec.count() << " ms";
const auto& visible_names =
cinn_compiled_object.launch_context->GetVisibleVarNames();
VLOG(1) << "These CINN variable can visible by Paddle: "
<< string::join_strings(visible_names, ", ");
}
details::DebugCinnCompiledResult(cinn_compiled_object);
auto* launch_context = cinn_compiled_object.launch_context.get();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册