Unverified commit c249556d, authored by Zhen Wang, committed by GitHub

Pass the stream created by Paddle to CINN. (#37337)

Parent a4ef88ed
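In outline, the patch threads the stream owned by Paddle's device context through `CinnCompiler::Compile` down to CINN's `RuntimeProgram::Execute`, so CINN kernels are enqueued on the same CUDA stream as the surrounding Paddle ops instead of a stream CINN creates itself. A minimal sketch of that flow, using stand-in types rather than the real Paddle/CINN classes:

// Sketch only: StubProgram and StubDeviceContext are stand-ins for CINN's
// runtime program and Paddle's CUDADeviceContext; the point is the plumbing.
#include <iostream>

struct StubProgram {
  void Execute(void* stream) {
    // The real CINN runtime enqueues its kernels on `stream` here.
    std::cout << "CINN executes on stream " << stream << "\n";
  }
};

struct StubDeviceContext {
  explicit StubDeviceContext(void* s) : stream_(s) {}
  void* stream() const { return stream_; }

 private:
  void* stream_;
};

// Mirrors LaunchCinnExecution below: the stream is passed through untouched.
void Launch(StubProgram& program, void* stream) { program.Execute(stream); }

int main() {
  int fake_stream;  // stand-in for a cudaStream_t owned by Paddle
  StubDeviceContext dev_ctx(&fake_stream);
  StubProgram program;
  Launch(program, dev_ctx.stream());  // Paddle's stream reaches CINN
}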
@@ -66,7 +66,7 @@ CinnCompiler* CinnCompiler::GetInstance() {
 const CinnCompiledObject& CinnCompiler::Compile(
     const Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target) {
+    const Target& target, void* stream) {
   VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph);
   CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
   bool exist = false;
@@ -77,7 +77,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
   if (!exist) {
     std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
     auto compiled_res =
-        CompileGraph(graph, input_tensors, target, compiled_num);
+        CompileGraph(graph, input_tensors, target, compiled_num, stream);
     AutoWRLock w_guard{&rwlock_};
     if (!cache_.count(cur_key)) {
       cache_[cur_key] = std::move(compiled_res);
@@ -91,9 +91,9 @@ const CinnCompiledObject& CinnCompiler::Compile(
 const CinnCompiledObject& CinnCompiler::Compile(
     const std::string& compilation_key,
     const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target) {
+    const Target& target, void* stream) {
   const auto& graph = FindGraph(compilation_key);
-  return Compile(graph, input_tensors, target);
+  return Compile(graph, input_tensors, target, stream);
 }

 std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
@@ -189,7 +189,7 @@ void CinnCompiler::Clear() {
 std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
     const ir::Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target, std::int64_t compiled_num) const {
+    const Target& target, std::int64_t compiled_num, void* stream) const {
   CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
   auto frontend_program = symbol();
   ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
@@ -209,7 +209,8 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
       std::make_unique<GraphCompiler>(target, scope, cinn_graph);
   GraphCompiler::CompileOptions options;
   options.with_instantiate_variables = false;
-  auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids));
+  auto compiled_res =
+      graph_compiler->Build(options, std::move(fetch_ids), stream);
   auto compiled_obj = std::make_unique<CinnCompiledObject>();
   *compiled_obj = {std::move(graph_compiler),
                    std::move(compiled_res.runtime_program), scope,
......
@@ -55,12 +55,12 @@ class CinnCompiler {
   const CinnCompiledObject& Compile(
       const ir::Graph& graph,
       const std::map<std::string, const LoDTensor*>& input_tensors,
-      const ::cinn::common::Target& target);
+      const ::cinn::common::Target& target, void* stream = nullptr);

   const CinnCompiledObject& Compile(
       const std::string& compilation_key,
       const std::map<std::string, const LoDTensor*>& input_tensors,
-      const ::cinn::common::Target& target);
+      const ::cinn::common::Target& target, void* stream = nullptr);

   std::string AddGraph(std::unique_ptr<ir::Graph> graph);
@@ -83,7 +83,8 @@ class CinnCompiler {
   std::unique_ptr<CinnCompiledObject> CompileGraph(
       const ir::Graph& graph,
       const std::map<std::string, const LoDTensor*>& input_tensors,
-      const ::cinn::common::Target& target, std::int64_t compiled_num) const;
+      const ::cinn::common::Target& target, std::int64_t compiled_num,
+      void* stream = nullptr) const;

   std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
   std::unordered_map<CinnCacheKey, std::unique_ptr<CinnCompiledObject>,
......
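Because `stream` defaults to `nullptr` in both public `Compile` overloads, every existing call site keeps building unchanged; only the CUDA kernel below passes a real stream. A hypothetical call-site sketch, where the include path and namespaces are assumptions about the surrounding Paddle tree rather than something shown in this diff:

// Hypothetical usage; names outside the declarations above are assumed.
#include <map>
#include <string>
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

using paddle::framework::paddle2cinn::CinnCompiler;

void CompileBothWays(
    const std::string& key,
    const std::map<std::string, const paddle::framework::LoDTensor*>& tensors,
    const ::cinn::common::Target& target, void* stream) {
  auto* compiler = CinnCompiler::GetInstance();
  compiler->Compile(key, tensors, target);          // old form, stream == nullptr
  compiler->Compile(key, tensors, target, stream);  // new form, Paddle's stream
}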
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/operators/cinn_launch_op.h"
+#include <vector>
 #include "paddle/fluid/string/string_helper.h"

 DECLARE_bool(cudnn_deterministic);
@@ -65,8 +66,8 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result) {
 }

 void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
-                         const CinnLaunchContext& context) {
-  compiled_obj.runtime_program->Execute(&context.FinalizeArguments());
+                         const CinnLaunchContext& context, void* stream) {
+  compiled_obj.runtime_program->Execute(&context.FinalizeArguments(), stream);
 }

 void SetCinnRuntimeFlags() {
......
@@ -13,6 +13,56 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/cinn_launch_op.h"
+#include <memory>
+#include <vector>
+#include "cinn/runtime/cinn_runtime.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/type_defs.h"
+#ifdef PADDLE_WITH_CUDA
+#include <cuda_runtime.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace details {
+
+#ifdef PADDLE_WITH_CUDA
+void CUDART_CB ReleaseScope(void* data) {
+  auto* temp_scope = static_cast<framework::Scope*>(data);
+  delete temp_scope;
+}
+
+void CUDART_CB ReleaseBuffers(void* data) {
+  auto* buffers =
+      static_cast<std::vector<std::unique_ptr<cinn_buffer_t>>*>(data);
+  delete buffers;
+}
+
+template <>
+void ReleaseResource<platform::CUDADeviceContext>(
+    const std::vector<void*>& resources, void* stream) {
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaLaunchHostFunc(
+      static_cast<gpuStream_t>(stream), ReleaseScope, resources[0]));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaLaunchHostFunc(
+      static_cast<gpuStream_t>(stream), ReleaseBuffers, resources[1]));
+}
+
+template <>
+void* GetStream<platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx) {
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  return dev_ctx.stream();
+}
+#endif
+
+}  // namespace details
+}  // namespace operators
+}  // namespace paddle
+
 /* see [Why use single type kernel] */
 REGISTER_OP_CUDA_KERNEL(cinn_launch,
......
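The CUDA specialization above leans on `cudaLaunchHostFunc`, which enqueues a host callback that the driver runs only after all work previously submitted to the stream has completed, so the host-side `delete` cannot race with in-flight kernels that still read those buffers. A standalone sketch of the pattern, where the payload is an arbitrary heap object rather than a `cinn_buffer_t`:

// Standalone illustration of stream-ordered host-side release; build with
// nvcc. The vector payload here is illustrative, not from the patch.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

void CUDART_CB ReleasePayload(void* data) {
  // Runs on a CUDA-internal thread once the stream reaches this point;
  // everything enqueued earlier on the stream has already finished.
  delete static_cast<std::vector<float>*>(data);
  std::printf("payload released after preceding stream work\n");
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  auto* payload = new std::vector<float>(1 << 20, 1.0f);
  // ... kernels / async copies that still use `payload` go here ...
  cudaLaunchHostFunc(stream, ReleasePayload, payload);  // deferred delete
  cudaStreamSynchronize(stream);  // drain so the printf is visible before exit
  cudaStreamDestroy(stream);
  return 0;
}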
@@ -67,6 +67,10 @@ class CinnLaunchContext {
   // Finalize all execution arguments and return them
   const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const;

+  std::vector<std::unique_ptr<cinn_buffer_t>> HandoverBuffers() {
+    return std::move(hold_buffers_);
+  }
+
  private:
   // Get CinnTensor with CINN variable name
   CinnTensor GetCinnTensor(const std::string& var_name);
@@ -110,10 +114,35 @@ void DebugCinnCompiledResult(const CinnCompiledObject& result);

 // Launch cinn to execute compiled executable program and wait done
 void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
-                         const CinnLaunchContext& context);
+                         const CinnLaunchContext& context, void* stream);

 // Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
 void SetCinnRuntimeFlags();

+template <typename DeviceContext>
+void ReleaseResource(const std::vector<void*>& resources, void* stream) {
+  auto* temp_scope = static_cast<framework::Scope*>(resources[0]);
+  auto* buffers =
+      static_cast<std::vector<std::unique_ptr<cinn_buffer_t>>*>(resources[1]);
+  delete temp_scope;
+  delete buffers;
+}
+
+template <typename DeviceContext>
+void* GetStream(const framework::ExecutionContext& ctx) {
+  return nullptr;
+}
+
+#ifdef PADDLE_WITH_CUDA
+template <>
+void ReleaseResource<platform::CUDADeviceContext>(
+    const std::vector<void*>& resources, void* stream);
+
+template <>
+void* GetStream<platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx);
+#endif
+
 }  // namespace details

 template <typename DeviceContext, typename T>
@@ -122,6 +151,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto& scope = ctx.scope();
     const auto& place = ctx.GetPlace();
+    void* stream = details::GetStream<DeviceContext>(ctx);
     // Step 1. Find graph object and prepare input
     PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true,
                       platform::errors::NotFound(
@@ -146,7 +176,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     // Step 2. Get compilation result of the graph
     auto target = details::PlaceToCinnTarget(place);
     const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
-        compilation_key, inputs_name2tensor, target);
+        compilation_key, inputs_name2tensor, target, stream);
     details::DebugCinnCompiledResult(cinn_compiled_object);

     auto launch_context =
@@ -199,7 +229,7 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     // names, because they will not be used outside the graph
     // and should be destructed after computation finished.
     auto internal_variable_names = launch_context->GetInternalVariableNames();
-    auto temp_scope = scope.NewTmpScope();
+    framework::Scope* temp_scope = scope.NewTmpScope().release();
     for (const auto& var_name : internal_variable_names) {
       auto* tensor = temp_scope->Var(var_name)->GetMutable<LoDTensor>();
       launch_context->MutableTensorData(var_name, place, tensor, true);
@@ -210,8 +240,15 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     details::SetCinnRuntimeFlags();

     // Step 5. Launch CINN to execute the compiled executable program
-    details::LaunchCinnExecution(cinn_compiled_object, *launch_context);
+    VLOG(4) << "Run Cinn compiled executable program with stream: " << stream;
+    details::LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
     VLOG(4) << "CinnLaunchOp launch execution done.";
+
+    // Step 6. Release some resources, such as `temp_scope` and cinn_buffers.
+    auto* buffers_holder = new std::vector<std::unique_ptr<cinn_buffer_t>>{
+        launch_context->HandoverBuffers()};
+    details::ReleaseResource<DeviceContext>({temp_scope, buffers_holder},
+                                            stream);
   }
 };
......
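The generic `GetStream`/`ReleaseResource` templates give every non-CUDA backend a synchronous fallback (null stream, immediate delete), and the `platform::CUDADeviceContext` specializations swap in the stream-aware behavior; which pair runs is fixed when the kernel template is instantiated. A self-contained sketch of this dispatch, with stub device-context types in place of Paddle's:

// Sketch with stub types: CpuCtx/CudaCtx stand in for Paddle's device
// contexts; the mechanism (primary template + explicit specialization)
// is the same one the kernel above relies on.
#include <iostream>

struct CpuCtx {};
struct CudaCtx {};

template <typename DeviceContext>
void* GetStream() {
  return nullptr;  // default: backends without a stream concept
}

template <>
void* GetStream<CudaCtx>() {
  static int fake_stream;  // stand-in for dev_ctx.stream()
  return &fake_stream;
}

template <typename DeviceContext>
void Release(int* resource, void* /*stream*/) {
  delete resource;  // default: execution is synchronous, free immediately
}

template <>
void Release<CudaCtx>(int* resource, void* stream) {
  // The real code defers this via cudaLaunchHostFunc; simulated here.
  std::cout << "defer delete until stream " << stream << " drains\n";
  delete resource;
}

int main() {
  Release<CpuCtx>(new int(1), GetStream<CpuCtx>());    // immediate
  Release<CudaCtx>(new int(2), GetStream<CudaCtx>());  // deferred (simulated)
}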