未验证 提交 a9bd6f0c 编写于 作者: C CtfGo 提交者: GitHub

change serval variable name and usage related cinn_launch (#38022)

上级 43f19cc3
...@@ -35,8 +35,8 @@ namespace paddle { ...@@ -35,8 +35,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace details { namespace details {
class CinnLaunchContext; class CinnLaunchContext;
} } // namespace details
} } // namespace operators
namespace framework { namespace framework {
namespace paddle2cinn { namespace paddle2cinn {
......
...@@ -55,9 +55,11 @@ bool CinnLaunchContext::IsArgumentsInitialized() const { ...@@ -55,9 +55,11 @@ bool CinnLaunchContext::IsArgumentsInitialized() const {
return true; return true;
} }
bool CinnLaunchContext::IsVariableUsed(const std::string& paddle_name) const { bool CinnLaunchContext::IsVariableUsed(
return paddle2cinn_varmap_.count(paddle_name) > 0 && const std::string& paddle_var_name) const {
cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_name)) > 0; return paddle2cinn_varmap_.count(paddle_var_name) > 0 &&
cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_var_name)) >
0;
} }
CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) { CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
...@@ -76,31 +78,33 @@ std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() { ...@@ -76,31 +78,33 @@ std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
return all_parameters; return all_parameters;
} }
void CinnLaunchContext::CheckTensorEquivalent(const std::string& paddle_name, void CinnLaunchContext::CheckTensorEquivalent(
const LoDTensor& paddle_tensor, const std::string& paddle_var_name, const LoDTensor& paddle_tensor,
const CinnTensor& cinn_tensor) { const CinnTensor& cinn_tensor) {
// check dimension // check dimension
auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data()); auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims, PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Tensors' shape in variable(%s) are not equivalent, " "Tensors' shape in variable(%s) are not equivalent, "
"paddle's shape = [%s], but cinn's shape = [%s].", "paddle's shape = [%s], but cinn's shape = [%s].",
paddle_name, paddle_tensor.dims(), cinn_dims)); paddle_var_name, paddle_tensor.dims(), cinn_dims));
// TODO(CtfGo): check the underlying data type after CINN ready // TODO(CtfGo): check the underlying data type after CINN ready
} }
void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) { void CinnLaunchContext::AssignExternalVariable(
PADDLE_ENFORCE_EQ(IsVariableUsed(paddle_name), true, const std::string& paddle_var_name) {
platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(
"Paddle variable(%s) not used by cinn", paddle_name)); IsVariableUsed(paddle_var_name), true,
platform::errors::InvalidArgument("Paddle variable(%s) not used by cinn",
paddle_var_name));
const auto& cinn_name = paddle2cinn_varmap_.at(paddle_name); const auto& cinn_var_name = paddle2cinn_varmap_.at(paddle_var_name);
const auto& paddle_tensor = const auto& paddle_tensor =
cached_scope_->GetVar(paddle_name)->Get<LoDTensor>(); cached_scope_->GetVar(paddle_var_name)->Get<LoDTensor>();
CinnTensor cinn_tensor = GetCinnTensor(cinn_name); CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
if (paddle_tensor.IsInitialized()) { if (paddle_tensor.IsInitialized()) {
CheckTensorEquivalent(paddle_name, paddle_tensor, cinn_tensor); CheckTensorEquivalent(paddle_var_name, paddle_tensor, cinn_tensor);
} }
auto cinn_buffer = std::make_unique<cinn_buffer_t>(); auto cinn_buffer = std::make_unique<cinn_buffer_t>();
...@@ -108,9 +112,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) { ...@@ -108,9 +112,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
cinn_buffer->resize(cinn_tensor->shape().data().data(), cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size()); cinn_tensor->shape().data().size());
cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>( cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
[this, paddle_name](void* ctx, cinn_buffer_t* buffer) { [this, paddle_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor = auto* tensor =
cached_scope_->GetVar(paddle_name)->GetMutable<LoDTensor>(); cached_scope_->GetVar(paddle_var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>( buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_)); tensor->mutable_data<float>(*cached_place_));
...@@ -124,23 +128,25 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) { ...@@ -124,23 +128,25 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
return 0; return 0;
}); });
return SetArgument(cinn_name, std::move(cinn_buffer)); return SetArgument(cinn_var_name, std::move(cinn_buffer));
} }
void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) { void CinnLaunchContext::AssignInternalVariable(
PADDLE_ENFORCE_GT(cinn_variable_names_.count(cinn_name), 0, const std::string& cinn_var_name) {
platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"Variable(%s) not found in cinn socpe.", cinn_name)); cinn_variable_names_.count(cinn_var_name), 0,
CinnTensor cinn_tensor = GetCinnTensor(cinn_name); platform::errors::InvalidArgument("Variable(%s) not found in cinn socpe.",
cinn_var_name));
CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
auto cinn_buffer = std::make_unique<cinn_buffer_t>(); auto cinn_buffer = std::make_unique<cinn_buffer_t>();
// assign dimensions and alloc/free callback of cinn_buffer_t // assign dimensions and alloc/free callback of cinn_buffer_t
cinn_buffer->resize(cinn_tensor->shape().data().data(), cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size()); cinn_tensor->shape().data().size());
cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>( cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
[this, cinn_name](void* ctx, cinn_buffer_t* buffer) { [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor = auto* tensor =
cached_temp_scope_->Var(cinn_name)->GetMutable<LoDTensor>(); cached_temp_scope_->Var(cinn_var_name)->GetMutable<LoDTensor>();
tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions)); tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
buffer->memory = reinterpret_cast<uint8_t*>( buffer->memory = reinterpret_cast<uint8_t*>(
tensor->mutable_data<float>(*cached_place_)); tensor->mutable_data<float>(*cached_place_));
...@@ -150,22 +156,22 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) { ...@@ -150,22 +156,22 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) {
// internal variables should release its buffer immediately // internal variables should release its buffer immediately
// if no instruction use it // if no instruction use it
cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>( cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
[this, cinn_name](void* ctx, cinn_buffer_t* buffer) { [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
auto* tensor = auto* tensor =
cached_temp_scope_->GetVar(cinn_name)->GetMutable<LoDTensor>(); cached_temp_scope_->GetVar(cinn_var_name)->GetMutable<LoDTensor>();
tensor->clear(); tensor->clear();
return 0; return 0;
}); });
return SetArgument(cinn_name, std::move(cinn_buffer)); return SetArgument(cinn_var_name, std::move(cinn_buffer));
} }
void CinnLaunchContext::SetArgument(const std::string& cinn_name, void CinnLaunchContext::SetArgument(const std::string& cinn_var_name,
std::unique_ptr<cinn_buffer_t>&& buffer) { std::unique_ptr<cinn_buffer_t>&& buffer) {
VLOG(4) << "SetArgument-" << name2argument_.size() << ": name(" << cinn_name VLOG(4) << "SetArgument-" << name2argument_.size() << ": name("
<< "), dims(" << framework::DDim(buffer->dims, buffer->dimensions) << cinn_var_name << "), dims("
<< ")."; << framework::DDim(buffer->dims, buffer->dimensions) << ").";
name2argument_.emplace(cinn_name, buffer.get()); name2argument_.emplace(cinn_var_name, buffer.get());
hold_buffers_.emplace_back(std::move(buffer)); hold_buffers_.emplace_back(std::move(buffer));
} }
......
...@@ -49,13 +49,13 @@ class CinnLaunchContext { ...@@ -49,13 +49,13 @@ class CinnLaunchContext {
bool IsArgumentsInitialized() const; bool IsArgumentsInitialized() const;
// Return whether a Paddle variable used on compiled kernels // Return whether a Paddle variable used on compiled kernels
bool IsVariableUsed(const std::string& paddle_name) const; bool IsVariableUsed(const std::string& paddle_var_name) const;
// Assign tensor buffer to input or output variables // Assign tensor buffer to input or output variables
void AssignExternalVariable(const std::string& paddle_name); void AssignExternalVariable(const std::string& paddle_var_name);
// Assign tensor buffer to internal variables // Assign tensor buffer to internal variables
void AssignInternalVariable(const std::string& cinn_name); void AssignInternalVariable(const std::string& cinn_var_name);
// Extract internal variable names from CinnScope // Extract internal variable names from CinnScope
// by excluding used input and output variables // by excluding used input and output variables
...@@ -75,7 +75,7 @@ class CinnLaunchContext { ...@@ -75,7 +75,7 @@ class CinnLaunchContext {
const CinnTensor& cinn_tensor); const CinnTensor& cinn_tensor);
// Set an argument with (cinn name)->(cinn_buffer_t) pair // Set an argument with (cinn name)->(cinn_buffer_t) pair
void SetArgument(const std::string& cinn_name, void SetArgument(const std::string& cinn_var_name,
std::unique_ptr<cinn_buffer_t>&& buffer); std::unique_ptr<cinn_buffer_t>&& buffer);
private: private:
......
...@@ -103,7 +103,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -103,7 +103,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
compilation_key, inputs_name2tensor, target, stream); compilation_key, inputs_name2tensor, target, stream);
details::DebugCinnCompiledResult(cinn_compiled_object); details::DebugCinnCompiledResult(cinn_compiled_object);
const auto& launch_context = cinn_compiled_object.launch_context; auto* launch_context = cinn_compiled_object.launch_context.get();
// Step 3. Prepare arguments needed for the compiled executable program. // Step 3. Prepare arguments needed for the compiled executable program.
launch_context->UpdateCapturedEnv(scope, place); launch_context->UpdateCapturedEnv(scope, place);
if (!launch_context->IsArgumentsInitialized()) { if (!launch_context->IsArgumentsInitialized()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册