Unverified commit 167d511f, authored by TeFeng Chen, committed by GitHub

cinn_launch_op: switch to execution by PE (#39911)

* switch to PE execution in cinn_launch

* fix outer variables being erased

* temporarily skip the map bug for tests

* temporary solution for batch_norm bug

* update comment

* fix compile error

* cinn_instruction_run_op_test: update code to skip external alloc/free instructions generated
Parent d8b40223
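In outline, the commit stops launching the compiled CINN runtime program directly and instead executes a Paddle graph of cinn_instruction_run ops through a ParallelExecutor (PE). Below is a minimal, self-contained sketch of that control flow; Scope, ParallelExecutor, and LaunchContext here are hypothetical stubs, and only the call sequence mirrors the diffs that follow.

// Simplified, stand-alone model of the new run path; the types below are
// stubs, not the real Paddle classes.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Scope {
  std::vector<std::unique_ptr<Scope>> kids;
  Scope& NewScope() {  // an execution scope owned by the caller's scope
    kids.emplace_back(std::make_unique<Scope>());
    return *kids.back();
  }
};

struct ParallelExecutor {
  void RunWithoutFetch(const std::vector<std::string>& skip_eager_vars) {
    // run every cinn_instruction_run node; vars in skip_eager_vars are
    // exempt from eager deletion because the outer graph still needs them
    std::cout << "skipping eager deletion of " << skip_eager_vars.size()
              << " vars\n";
  }
};

struct LaunchContext {
  ParallelExecutor pe;  // created once, scope rebound on each launch
  std::vector<std::string> skip_eager_vars{"x", "y", "out"};
  ParallelExecutor* InitializePE(Scope* /*exec_scope*/) { return &pe; }
  const std::vector<std::string>& GetSkipEagerVars() const {
    return skip_eager_vars;
  }
};

int main() {
  Scope scope;
  LaunchContext ctx;
  Scope& exec_scope = scope.NewScope();  // Step 5 of the kernel diff below
  auto* pe = ctx.InitializePE(&exec_scope);
  pe->RunWithoutFetch(ctx.GetSkipEagerVars());
  return 0;
}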
@@ -241,7 +241,6 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
       std::make_unique<GraphCompiler>(target, scope, cinn_graph);
   GraphCompiler::CompileOptions options;
   options.with_instantiate_variables = false;
-  options.with_buffer_handle_instruction_inserted = true;
   auto compiled_res =
       graph_compiler->Build(options, std::move(fetch_ids), stream);
   auto compiled_obj = std::make_unique<CinnCompiledObject>();
......
include(operators)
cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context)
-cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy parallel_executor cinn)
+cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy device_context parallel_executor cinn)
-SET(CINN_OP_DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
+SET(CINN_OP_DEPS parallel_executor string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
register_operators(DEPS ${CINN_OP_DEPS})

if (WITH_TESTING)
@@ -11,7 +11,7 @@ if (WITH_TESTING)
  set_tests_properties(cinn_launch_context_test PROPERTIES LABELS "RUN_TYPE=CINN")
  SET(CINN_RUN_ENVIRONMENT "OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda")
-  cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op)
+  cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op cinn_instruction_run_op elementwise_add_op gflags)
  set_tests_properties(cinn_launch_op_test PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT "${CINN_RUN_ENVIRONMENT}")
  cc_test(cinn_instruction_run_op_test SRCS cinn_instruction_run_op_test.cc DEPS cinn_compiler cinn_launch_op cinn_instruction_run_op elementwise_add_op)
......
@@ -50,7 +50,7 @@ TEST(CinnInstructionOpTest, TestWithElementwiseAdd) {
   auto cinn_instruction_run_op = paddle::framework::OpRegistry::CreateOp(
       "cinn_instruction_run", {{"X", {"x", "y"}}},
       {{"Out", {test_op_out_name}}},
-      {{"cached_index", 0}, {"instruction_index", 1}});
+      {{"cached_index", 0}, {"instruction_index", 0}});
   auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
       "elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}},
       {{"Out", {add_op_out_name}}}, {{}});
......
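This attribute change follows from dropping with_buffer_handle_instruction_inserted in the first hunk: once the compiler no longer emits external malloc/free instructions, the elementwise_add computation is presumably the first instruction in the compiled program, so the test indexes it at 0 instead of 1. A small stand-alone illustration of that indexing assumption (the instruction names are invented):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Hypothetical instruction layouts, for illustration only.
  // With buffer-handle instructions inserted, the compute kernel sits at 1:
  std::vector<std::string> with_handles = {"external_malloc", "elementwise_add",
                                           "external_free"};
  // Without them (as after this commit), it sits at 0:
  std::vector<std::string> without_handles = {"elementwise_add"};
  assert(with_handles[1] == "elementwise_add");     // old instruction_index = 1
  assert(without_handles[0] == "elementwise_add");  // new instruction_index = 0
  return 0;
}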
@@ -31,6 +31,7 @@
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/operators/cinn/cinn_op_helper.h"
+#include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/phi/core/ddim.h"
@@ -90,9 +91,30 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
   // Convert the CINN runtime program to a Paddle graph
   runtime_graph_ = std::make_unique<framework::ir::Graph>(
       BuildCompiledProgram(graph, compiled_obj));
-  runtime_graph_->SetNotOwned<Name2VarInfoMap>(
-      kMemOptVarInfoFromMainGraph,
-      &graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph));
+  auto& outer_varinfo = graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
+  runtime_graph_->SetNotOwned<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph,
+                                               &outer_varinfo);
+  // collect skip_eager_vars
+  skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
+  auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
+    // if a var exists in the outer_varinfo map,
+    // it can be erased after graph execution
+    if (!outer_varinfo.count(var_name)) {
+      skip_eager_vars_.emplace_back(var_name);
+    }
+  };
+  std::for_each(input_var_names.begin(), input_var_names.end(),
+                add_skip_var_fn);
+  std::for_each(output_var_names.begin(), output_var_names.end(),
+                add_skip_var_fn);
+  VLOG(4) << string::Sprintf(
+      "Distribution of variables in the graph compiled:"
+      "input[%lu],internal[%lu],output[%lu],"
+      "outer_eager_deletion[%lu],skip_eager_deletion[%lu],"
+      "initialized_beforehand[%lu]",
+      input_var_names.size(), internal_var_names_.size(),
+      output_var_names.size(), outer_varinfo.size(), skip_eager_vars_.size(),
+      initialized_beforehand_vars_.size());
 }

 void CinnLaunchContext::BuildVarNameMap(
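The constructor now partitions the subgraph's in/out variables by whether the outer graph's memory-optimization info already tracks them; only untracked variables are added to skip_eager_vars_ so the PE will not eagerly delete them. A stand-alone sketch of the same classification, with a plain std::unordered_map standing in for Name2VarInfoMap:

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // stand-in for kMemOptVarInfoFromMainGraph: vars the outer graph manages
  std::unordered_map<std::string, int> outer_varinfo = {{"x", 0}};
  std::vector<std::string> inputs = {"x", "y"}, outputs = {"out"};

  std::vector<std::string> skip_eager_vars;
  skip_eager_vars.reserve(inputs.size() + outputs.size());
  auto add_skip_var_fn = [&](const std::string& name) {
    // vars unknown to the outer graph must not be eagerly deleted here
    if (!outer_varinfo.count(name)) skip_eager_vars.emplace_back(name);
  };
  std::for_each(inputs.begin(), inputs.end(), add_skip_var_fn);
  std::for_each(outputs.begin(), outputs.end(), add_skip_var_fn);

  for (auto& v : skip_eager_vars) std::cout << v << '\n';  // prints y, out
  return 0;
}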
@@ -288,6 +310,7 @@ framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram(
   // are set by values of the corresponding compiled tensors,
   // including the in/out variables where the equality between their tensors
   // and the CINN compiled ones is verified in the corresponding cinn_launch_op.
+  std::unordered_set<std::string> has_refer_vars;
   for (auto&& arg : cinn_argument_names_) {
     const std::string& var_name = cinn2paddle_varmap_.at(arg);
     framework::VarDesc* var_desc = block->Var(var_name);
@@ -298,6 +321,7 @@ framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram(
       auto* ori_desc = res->second;
       var_desc->SetPersistable(ori_desc->Persistable());
       var_desc->SetIsParameter(ori_desc->IsParameter());
+      has_refer_vars.insert(var_name);
     }

     auto cinn_tensor = GetCinnTensorOfVar(var_name);
@@ -331,6 +355,12 @@ framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram(
     auto* ins = instructions.at(ins_idx).get();
     auto in_args = trans_and_pack_args_fn(ins->GetInArgs());
     auto out_args = trans_and_pack_args_fn(ins->GetOutArgs());
+    for (auto&& var_name : in_args) {
+      if (!has_refer_vars.count(var_name)) {
+        initialized_beforehand_vars_.emplace_back(var_name);
+      }
+    }
+    has_refer_vars.insert(out_args.begin(), out_args.end());

     auto* op_desc = block->AppendOp();
     op_desc->SetType("cinn_instruction_run");
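The inserted loop marks any instruction input that no earlier instruction produced and that carries no VarDesc reference from the original graph; such variables must be initialized before the PE run. A minimal model of that scan (the instruction lists are invented for illustration):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Instr {
  std::vector<std::string> in_args, out_args;
};

int main() {
  // referenced paddle vars (graph in/outs); "tmp" is produced by no one
  std::unordered_set<std::string> has_refer_vars = {"x", "y", "out"};
  std::vector<Instr> instructions = {{{"x", "tmp"}, {"z"}},
                                     {{"z", "y"}, {"out"}}};

  std::vector<std::string> initialized_beforehand_vars;
  for (const auto& ins : instructions) {
    for (const auto& name : ins.in_args) {
      if (!has_refer_vars.count(name)) {
        initialized_beforehand_vars.emplace_back(name);  // e.g. "tmp"
      }
    }
    // outputs of this instruction are available to later ones
    has_refer_vars.insert(ins.out_args.begin(), ins.out_args.end());
  }
  for (auto& v : initialized_beforehand_vars) std::cout << v << '\n';
  return 0;
}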
@@ -348,16 +378,26 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
                                                   framework::Scope* scope) {
   if (!parallel_executor_) {
     framework::details::ExecutionStrategy exec_strategy;
+    exec_strategy.num_threads_ = 1;
+    exec_strategy.use_device_ = platform::Place2DeviceType(place);
     framework::details::BuildStrategy build_strategy;
     parallel_executor_ = std::make_unique<ParallelExecutor>(
         place, scope, exec_strategy, build_strategy, runtime_graph_.get());
   }

   // update the scope bound to an OpHandle and rebuild temporary variables
+  VLOG(4) << "Reset scope and initialize temporary variables";
   std::unordered_map<Scope*, Scope*> scope_map = {
       {parallel_executor_->GetLocalScopes().front(), scope}};
   parallel_executor_->ResetOpHandleScopeMapOfGraphs(scope_map);
   parallel_executor_->PrepareVariables(scope);
+  for (auto&& var_name : initialized_beforehand_vars_) {
+    auto* var = scope->GetVar(var_name);
+    auto* buffer = GetCinnBufferOfVar(var_name);
+    auto dim = framework::DDim(buffer->dims, buffer->dimensions);
+    var->GetMutable<LoDTensor>()->Resize(dim);
+    var->GetMutable<LoDTensor>()->mutable_data<float>(place);
+  }

   return parallel_executor_.get();
 }
......
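The pre-initialization step sizes each such variable from its cinn_buffer_t, whose shape arrives as a raw dims pointer plus a rank. A self-contained sketch of turning that pair into a shape and an allocation size; FakeBuffer is a stand-in for illustration, not the real CINN struct:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// stand-in for cinn_buffer_t's shape fields: raw dims + rank
struct FakeBuffer {
  const int64_t* dims;
  int dimensions;
};

int main() {
  const int64_t shape[] = {16, 32};
  FakeBuffer buffer{shape, 2};

  // mirrors framework::DDim(buffer->dims, buffer->dimensions)
  std::vector<int64_t> dim(buffer.dims, buffer.dims + buffer.dimensions);
  int64_t numel = std::accumulate(dim.begin(), dim.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  // a real LoDTensor would Resize(dim) and allocate numel * sizeof(float)
  std::cout << "alloc " << numel * sizeof(float) << " bytes\n";
  return 0;
}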
@@ -86,6 +86,11 @@ class CinnLaunchContext {
   void CheckTensorEquivalent(const std::string& var_name,
                              const framework::LoDTensor& paddle_tensor);

+  // Return the name list of variables skipped from eager deletion
+  const std::vector<std::string>& GetSkipEagerVars() const {
+    return skip_eager_vars_;
+  }
+
   // Return the list of internal variable names
   const std::unordered_set<std::string>& GetInternalVarNames() const {
     return internal_var_names_;
@@ -143,6 +148,9 @@ class CinnLaunchContext {
   std::unordered_set<std::string> internal_var_names_;
   // the names of the cinn arguments used in the compiled executable program
   std::unordered_set<std::string> cinn_argument_names_;
+  // TODO(CtfGo): remove this list after fixing the batch_norm bug
+  // caused by duplicate associations to the same variable.
+  std::vector<std::string> initialized_beforehand_vars_;
   // the variable scope compiled from cinn
   const std::shared_ptr<CinnScope> cinn_scope_;
@@ -150,6 +158,8 @@ class CinnLaunchContext {
   std::unique_ptr<framework::ir::Graph> runtime_graph_;
   // a ParallelExecutor to execute the runtime graph
   std::unique_ptr<framework::ParallelExecutor> parallel_executor_;
+  // the name list of skip_eager_vars used at runtime
+  std::vector<std::string> skip_eager_vars_;
   // because a cinn_pod_value_t does not own a cinn_buffer_t object,
   // extra storage is necessary to keep those objects alive so they can
......
@@ -103,8 +103,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     details::DebugCinnCompiledResult(cinn_compiled_object);
     auto* launch_context = cinn_compiled_object.launch_context.get();

-    // Step 3. Prepare arguments needed for the compiled executable program.
-    launch_context->UpdateCapturedEnv(scope, place);
+    // Step 3. Check the computational consistency of the subgraph
+    //         before and after compilation.
     // 3.1 Input variables: tensors of input variables have
     //     been initialized before graph compilation, so just check the
     //     equality between the paddle and cinn tensors.
@@ -120,20 +120,15 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
                              *inputs_name2tensor.at(var_name));
     }

-    // 3.2 Output variables: the output variables will be initialized
-    //     and allocated buffer in callbacks which are defined in the
-    //     external_malloc/free interface of cinn_buffer_t
-    //     in their corresponding arguments.
-    // 3.3 Internal variables: A temporary scope is created in
-    //     UpdateCapturedEnv to keep the internal variables and
-    //     they are also initialized through callbacks
     // Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
     details::SetCinnRuntimeFlags();

-    // Step 5. Launch CINN to execute the compiled executable program
-    VLOG(4) << "Run Cinn compiled executable program with stream: " << stream;
-    details::LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
+    // Step 5. Use PE to execute the compiled CINN instructions
+    //         in the nodes of the runtime graph.
+    VLOG(4) << "Execute the runtime graph by PE";
+    framework::Scope& exec_scope = scope.NewScope();
+    auto* pe = launch_context->InitializePE(place, &exec_scope);
+    pe->RunWithoutFetch(launch_context->GetSkipEagerVars());
     VLOG(4) << "CinnLaunchOp launch execution done.";
   }
 };
......
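A design point in the new Step 5: internal variables now live in a child scope created from the op's scope, so their lifetime follows Paddle scope ownership instead of CINN's external malloc/free callbacks. A toy model of that ownership (the Scope class here is a simplification):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

class Scope {
 public:
  Scope& NewScope() {  // a child lives exactly as long as its parent
    children_.emplace_back(new Scope);
    return *children_.back();
  }
  float* Var(const std::string& name) { return &vars_[name]; }

 private:
  std::unordered_map<std::string, float> vars_;
  std::vector<std::unique_ptr<Scope>> children_;
};

int main() {
  Scope scope;                            // the op's scope: holds x, y, out
  Scope& exec_scope = scope.NewScope();   // holds CINN internal temporaries
  *exec_scope.Var("tmp_0") = 1.0f;
  std::cout << *exec_scope.Var("tmp_0") << '\n';
  return 0;  // exec_scope (and tmp_0) destroyed together with scope
}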
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <mutex>
 #include <random>
 #include <string>
+#include "gflags/gflags.h"
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
@@ -27,7 +28,9 @@ limitations under the License. */
 #include "paddle/phi/core/ddim.h"

 USE_OP(cinn_launch);
+USE_OP(cinn_instruction_run);
 USE_OP_ITSELF(elementwise_add);
+DECLARE_double(eager_delete_tensor_gb);

 namespace paddle::operators {
@@ -61,6 +64,7 @@ TEST(CinnLaunchOpTest, TestWithElementwiseAdd) {
     CompareOpResult<float>(scope.GetVar(test_op_out_name),
                            scope.GetVar(add_op_out_name));
   };
+  FLAGS_eager_delete_tensor_gb = -1;

   // CPU
   run_and_check_fn(platform::CPUPlace());
......
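Setting FLAGS_eager_delete_tensor_gb to a negative value disables eager deletion, so the output tensors compared by the test survive the PE run. The DECLARE/DEFINE pairing the test relies on is standard gflags; a minimal self-contained example (the flag is redefined here only for illustration):

#include <iostream>
#include "gflags/gflags.h"

// the defining translation unit (in Paddle this lives in the framework)
DEFINE_double(eager_delete_tensor_gb, 0.0,
              "threshold for eager deletion; negative disables it");

// a consuming translation unit, like the test, would instead write:
// DECLARE_double(eager_delete_tensor_gb);

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  FLAGS_eager_delete_tensor_gb = -1;  // as the test does: eager deletion off
  std::cout << FLAGS_eager_delete_tensor_gb << '\n';
  return 0;
}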