未验证 提交 a9bd6f0c 编写于 作者: C CtfGo 提交者: GitHub

change serval variable name and usage related cinn_launch (#38022)

上级 43f19cc3
......@@ -35,8 +35,8 @@ namespace paddle {
namespace operators {
namespace details {
class CinnLaunchContext;
}
}
} // namespace details
} // namespace operators
namespace framework {
namespace paddle2cinn {
......
......@@ -55,9 +55,11 @@ bool CinnLaunchContext::IsArgumentsInitialized() const {
return true;
}
bool CinnLaunchContext::IsVariableUsed(const std::string& paddle_name) const {
return paddle2cinn_varmap_.count(paddle_name) > 0 &&
cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_name)) > 0;
bool CinnLaunchContext::IsVariableUsed(
const std::string& paddle_var_name) const {
return paddle2cinn_varmap_.count(paddle_var_name) > 0 &&
cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_var_name)) >
0;
}
CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
......@@ -76,8 +78,8 @@ std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
return all_parameters;
}
void CinnLaunchContext::CheckTensorEquivalent(const std::string& paddle_name,
const LoDTensor& paddle_tensor,
void CinnLaunchContext::CheckTensorEquivalent(
const std::string& paddle_var_name, const LoDTensor& paddle_tensor,
const CinnTensor& cinn_tensor) {
// check dimension
auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
......@@ -85,22 +87,24 @@ void CinnLaunchContext::CheckTensorEquivalent(const std::string& paddle_name,
platform::errors::PreconditionNotMet(
"Tensors' shape in variable(%s) are not equivalent, "
"paddle's shape = [%s], but cinn's shape = [%s].",
paddle_name, paddle_tensor.dims(), cinn_dims));
paddle_var_name, paddle_tensor.dims(), cinn_dims));
// TODO(CtfGo): check the underlying data type after CINN ready
}
void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
PADDLE_ENFORCE_EQ(IsVariableUsed(paddle_name), true,
platform::errors::InvalidArgument(
"Paddle variable(%s) not used by cinn", paddle_name));
void CinnLaunchContext::AssignExternalVariable(
const std::string& paddle_var_name) {
PADDLE_ENFORCE_EQ(
IsVariableUsed(paddle_var_name), true,
platform::errors::InvalidArgument("Paddle variable(%s) not used by cinn",
paddle_var_name));
const auto& cinn_name = paddle2cinn_varmap_.at(paddle_name);
const auto& cinn_var_name = paddle2cinn_varmap_.at(paddle_var_name);
const auto& paddle_tensor =
cached_scope_->GetVar(paddle_name)->Get<LoDTensor>();
CinnTensor cinn_tensor = GetCinnTensor(cinn_name);
cached_scope_->GetVar(paddle_var_name)->Get<LoDTensor>();
CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
if (paddle_tensor.IsInitialized()) {
CheckTensorEquivalent(paddle_name, paddle_tensor, cinn_tensor);
CheckTensorEquivalent(paddle_var_name, paddle_tensor, cinn_tensor);
}
auto cinn_buffer = std::make_unique<cinn_buffer_t>();
......@@ -108,9 +112,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size());
cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
[this, paddle_name](void* ctx, cinn_buffer_t* buffer) {
[this, paddle_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor =
cached_scope_->GetVar(paddle_name)->GetMutable<LoDTensor>();
cached_scope_->GetVar(paddle_var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_));
......@@ -124,23 +128,25 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
return 0;
});
return SetArgument(cinn_name, std::move(cinn_buffer));
return SetArgument(cinn_var_name, std::move(cinn_buffer));
}
void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) {
PADDLE_ENFORCE_GT(cinn_variable_names_.count(cinn_name), 0,
platform::errors::InvalidArgument(
"Variable(%s) not found in cinn socpe.", cinn_name));
CinnTensor cinn_tensor = GetCinnTensor(cinn_name);
void CinnLaunchContext::AssignInternalVariable(
const std::string& cinn_var_name) {
PADDLE_ENFORCE_GT(
cinn_variable_names_.count(cinn_var_name), 0,
platform::errors::InvalidArgument("Variable(%s) not found in cinn socpe.",
cinn_var_name));
CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
auto cinn_buffer = std::make_unique<cinn_buffer_t>();
// assign dimensions and alloc/free callback of cinn_buffer_t
cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size());
cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
[this, cinn_name](void* ctx, cinn_buffer_t* buffer) {
[this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor =
cached_temp_scope_->Var(cinn_name)->GetMutable<LoDTensor>();
cached_temp_scope_->Var(cinn_var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_));
......@@ -150,22 +156,22 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) {
// internal variables should release its buffer immediately
// if no instruction use it
cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
[this, cinn_name](void* ctx, cinn_buffer_t* buffer) {
[this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor =
cached_temp_scope_->GetVar(cinn_name)->GetMutable<LoDTensor>();
cached_temp_scope_->GetVar(cinn_var_name)->GetMutable<LoDTensor>();
tensor->clear();
return 0;
});
return SetArgument(cinn_name, std::move(cinn_buffer));
return SetArgument(cinn_var_name, std::move(cinn_buffer));
}
void CinnLaunchContext::SetArgument(const std::string& cinn_name,
void CinnLaunchContext::SetArgument(const std::string& cinn_var_name,
std::unique_ptr<cinn_buffer_t>&& buffer) {
VLOG(4) << "SetArgument-" << name2argument_.size() << ": name(" << cinn_name
<< "), dims(" << framework::DDim(buffer->dims, buffer->dimensions)
<< ").";
VLOG(4) << "SetArgument-" << name2argument_.size() << ": name("
<< cinn_var_name << "), dims("
<< framework::DDim(buffer->dims, buffer->dimensions) << ").";
name2argument_.emplace(cinn_name, buffer.get());
name2argument_.emplace(cinn_var_name, buffer.get());
hold_buffers_.emplace_back(std::move(buffer));
}
......
......@@ -49,13 +49,13 @@ class CinnLaunchContext {
bool IsArgumentsInitialized() const;
// Return whether a Paddle variable used on compiled kernels
bool IsVariableUsed(const std::string& paddle_name) const;
bool IsVariableUsed(const std::string& paddle_var_name) const;
// Assign tensor buffer to input or output variables
void AssignExternalVariable(const std::string& paddle_name);
void AssignExternalVariable(const std::string& paddle_var_name);
// Assign tensor buffer to internal variables
void AssignInternalVariable(const std::string& cinn_name);
void AssignInternalVariable(const std::string& cinn_var_name);
// Extract internal variable names from CinnScope
// by excluding used input and output variables
......@@ -75,7 +75,7 @@ class CinnLaunchContext {
const CinnTensor& cinn_tensor);
// Set an argument with (cinn name)->(cinn_buffer_t) pair
void SetArgument(const std::string& cinn_name,
void SetArgument(const std::string& cinn_var_name,
std::unique_ptr<cinn_buffer_t>&& buffer);
private:
......
......@@ -103,7 +103,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
compilation_key, inputs_name2tensor, target, stream);
details::DebugCinnCompiledResult(cinn_compiled_object);
const auto& launch_context = cinn_compiled_object.launch_context;
auto* launch_context = cinn_compiled_object.launch_context.get();
// Step 3. Prepare arguments needed for the compiled executable program.
launch_context->UpdateCapturedEnv(scope, place);
if (!launch_context->IsArgumentsInitialized()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册