From 9d0baeabd70b8a107f2f03ba3a06f6bed271e66b Mon Sep 17 00:00:00 2001 From: TeFeng Chen Date: Tue, 15 Feb 2022 14:40:54 +0800 Subject: [PATCH] Add cinn_instruction_run_op for launching execution of a cinn instruction (#39435) * add cinn_instruction_run_op for launching execution of a cinn instruction * fix multi definition compilation error * update cmake * fix bug at infershape * fix compile error due to lacking header file --- .../framework/paddle2cinn/cinn_compiler.cc | 20 +++- .../framework/paddle2cinn/cinn_compiler.h | 11 +- .../paddle2cinn/cinn_compiler_test.cc | 7 ++ paddle/fluid/operators/cinn/CMakeLists.txt | 6 +- .../operators/cinn/cinn_instruction_run_op.cc | 109 ++++++++++++++++++ .../cinn/cinn_instruction_run_op.cu.cc | 26 +++++ .../operators/cinn/cinn_instruction_run_op.h | 76 ++++++++++++ .../operators/cinn/cinn_launch_context.cc | 33 ++++++ .../operators/cinn/cinn_launch_context.h | 13 ++- paddle/fluid/operators/cinn/cinn_launch_op.cc | 5 +- .../fluid/operators/cinn/cinn_launch_op.cu.cc | 29 ----- paddle/fluid/operators/cinn/cinn_launch_op.h | 24 +--- paddle/fluid/operators/cinn/cinn_op_helper.cc | 31 +++++ paddle/fluid/operators/cinn/cinn_op_helper.h | 47 ++++++++ 14 files changed, 370 insertions(+), 67 deletions(-) create mode 100644 paddle/fluid/operators/cinn/cinn_instruction_run_op.cc create mode 100644 paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc create mode 100644 paddle/fluid/operators/cinn/cinn_instruction_run_op.h create mode 100644 paddle/fluid/operators/cinn/cinn_op_helper.cc create mode 100644 paddle/fluid/operators/cinn/cinn_op_helper.h diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 535c9ab58e2..c62ece7f0dc 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -88,7 +88,7 @@ const CinnCompiledObject& CinnCompiler::Compile( if (cache_by_struct_.count(cur_key_by_struct) != 0) { exist = true; cache_by_address_[cur_key_by_address] = - cache_by_struct_.at(cur_key_by_struct).get(); + cache_by_struct_.at(cur_key_by_struct); } } } @@ -98,12 +98,13 @@ const CinnCompiledObject& CinnCompiler::Compile( CompileGraph(graph, input_tensors, target, compiled_num, stream); pten::AutoWRLock w_guard{&rwlock_}; if (!cache_by_struct_.count(cur_key_by_struct)) { - cache_by_address_[cur_key_by_address] = compiled_res.get(); - cache_by_struct_[cur_key_by_struct] = std::move(compiled_res); + cache_by_address_[cur_key_by_address] = compiled_num; + cache_by_struct_[cur_key_by_struct] = compiled_num; + index2cache_.emplace(compiled_num, std::move(compiled_res)); } } pten::AutoRDLock guard{&rwlock_}; - const auto& cached_boj = *cache_by_address_[cur_key_by_address]; + const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]]; return cached_boj; } @@ -115,6 +116,15 @@ const CinnCompiledObject& CinnCompiler::Compile( return Compile(graph, input_tensors, target, stream); } +const CinnCompiledObject& CinnCompiler::GetCompiledObject( + int64_t cached_index) const { + auto res = index2cache_.find(cached_index); + PADDLE_ENFORCE_NE(res, index2cache_.end(), + platform::errors::InvalidArgument( + "Index(%ld) not found in cache", cached_index)); + return *res->second; +} + std::string CinnCompiler::AddGraph(std::unique_ptr graph) { std::string graph_key; ProgramDesc program; @@ -202,6 +212,7 @@ void CinnCompiler::Clear() { graphs_.clear(); cache_by_address_.clear(); cache_by_struct_.clear(); + index2cache_.clear(); } real_compiled_num_.store(0); } @@ -240,6 +251,7 @@ std::unique_ptr CinnCompiler::CompileGraph( compiled_obj->launch_context = std::make_unique( compiled_obj->paddle2cinn_varmap, compiled_obj->scope); + compiled_obj->cached_index = compiled_num; return compiled_obj; } diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index 91a7b4e5a11..d7ae743111e 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -53,6 +53,7 @@ struct CinnCompiledObject { std::shared_ptr<::cinn::hlir::framework::Scope> scope; std::unordered_map paddle2cinn_varmap; std::unique_ptr launch_context; + std::int64_t cached_index; }; // Entrance to use CINN. @@ -76,6 +77,8 @@ class CinnCompiler { const std::map& input_tensors, const ::cinn::common::Target& target, void* stream = nullptr); + const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const; + std::string AddGraph(std::unique_ptr graph); const ir::Graph& FindGraph(const std::string& graph_key) const; @@ -101,12 +104,12 @@ class CinnCompiler { void* stream = nullptr) const; std::unordered_map> graphs_; - std::unordered_map + std::unordered_map cache_by_address_; - std::unordered_map, CinnCacheKey::Hash> + std::unordered_map cache_by_struct_; + std::unordered_map> + index2cache_; std::atomic_int64_t real_compiled_num_{0}; mutable pten::RWLock rwlock_; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index 6769413d99b..be51c7b783a 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -270,13 +270,20 @@ TEST(CinnCompilerTest, Compile) { auto compile_fn = [&](const Target& target) { const auto& compiled_obj = cinn_compiler->Compile(compiling_graph, input_tensors, target); + ASSERT_NE(compiled_obj.compiler, nullptr); ASSERT_NE(compiled_obj.runtime_program, nullptr); ASSERT_NE(compiled_obj.scope, nullptr); ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty()); + ASSERT_NE(compiled_obj.launch_context, nullptr); const auto& cached_obj = cinn_compiler->Compile(compilation_key, input_tensors, target); ASSERT_EQ(reinterpret_cast(&compiled_obj), reinterpret_cast(&cached_obj)); + ASSERT_EQ(cached_obj.cached_index + 1, cinn_compiler->real_compiled_num()); + const auto& ret_obj = + cinn_compiler->GetCompiledObject(cached_obj.cached_index); + ASSERT_EQ(reinterpret_cast(&compiled_obj), + reinterpret_cast(&ret_obj)); }; // GPU Compilation diff --git a/paddle/fluid/operators/cinn/CMakeLists.txt b/paddle/fluid/operators/cinn/CMakeLists.txt index ed3a7598bda..b80916616a1 100644 --- a/paddle/fluid/operators/cinn/CMakeLists.txt +++ b/paddle/fluid/operators/cinn/CMakeLists.txt @@ -1,8 +1,10 @@ include(operators) -register_operators(EXCLUDES cinn_launch_op) +cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context) cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn) -op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_launch_context) + +SET(CINN_OP_DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context) +register_operators(DEPS ${CINN_OP_DEPS}) if (WITH_TESTING) cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context) diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc new file mode 100644 index 00000000000..edf854a9c95 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn/cinn_launch_context.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle::operators { + +class CinnInstructionRunOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun"); + OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, + "CinnInstructionRun"); + const CinnCompiledObject& compiled_object = + CinnCompiler::GetInstance()->GetCompiledObject( + ctx->Attrs().Get(kCachedIndex)); + + details::CinnLaunchContext* launch_context = + compiled_object.launch_context.get(); + std::vector output_args = ctx->Outputs(kOutputs); + std::vector output_dims(output_args.size()); + std::transform(output_args.begin(), output_args.end(), output_dims.begin(), + [launch_context](const std::string& var_name) { + cinn_buffer_t* buffer = + launch_context->GetCinnBufferOfVar(var_name); + return framework::DDim(buffer->dims, buffer->dimensions); + }); + ctx->SetOutputsDim(kOutputs, output_dims); + } +}; + +class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput(kX, + "(vector)" + "which are the input arguments of this cinn instruction") + .AsDuplicable(); + AddOutput(kOutputs, + "(vector)" + "which are the output arguments of this cinn instruction") + .AsDuplicable(); + AddAttr( + kCachedIndex, + "(int64_t)" + "the stored index of the cached compilation result in CinnCompiler," + "which is used to fetch the CinnCompiledObject where this cinn " + "instruction is included"); + AddAttr( + kInstructionIndex, + "(int64_t)" + "the index of this instruction to the cinn runtime program"); + AddComment(R"DOC( +CinnInstructionRun Operator. + +This operator is used to launch a +CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) instruction execution + +Both the input and output of this operator are a set of variables +which are the input and output arguments of the bound cinn instruction respectively. +In addition, there is an attribute named 'cached_index' should be +set necessarily to get the CinnCompiledObject where the instruction is included +and 'instruction_index' is fetch the instruction object from complied runtime prograrm. + +It accomplishes the execution of the instruction according to the following steps: + 0. Set the shapes ot the output variables at InferShape function with + compilation result. + 1. Fetch the cinn instruction bound to this operator by 'cached_index' + and 'instruction_index' from CinnCompiler. + 2. Prepare the input and output variables of the instruction in Paddle and share + their buffers to CINN by setting 'memory' of according cinn_buffer_t. + 3. Launch CINN runtime to execute the instruction. + +)DOC"); + } +}; + +} // namespace paddle::operators + +namespace ops = paddle::operators; +using CPUDeviceContext = paddle::platform::CPUDeviceContext; +REGISTER_OPERATOR( + cinn_instruction_run, ops::CinnInstructionRunOp, + ops::CinnInstructionRunOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + cinn_instruction_run, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel); diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc new file mode 100644 index 00000000000..a1b00a18206 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace ops = paddle::operators; +using CUDADeviceContext = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL( + cinn_instruction_run, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel, + ops::CinnInstructionRunOpKernel); diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.h b/paddle/fluid/operators/cinn/cinn_instruction_run_op.h new file mode 100644 index 00000000000..8847faa944b --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.h @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/instruction.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn/cinn_launch_context.h" +#include "paddle/fluid/operators/cinn/cinn_op_helper.h" + +namespace paddle::operators { + +using CinnInstruction = ::cinn::hlir::framework::Instruction; +using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject; +using CinnCompiler = framework::paddle2cinn::CinnCompiler; + +template +class CinnInstructionRunOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // step 1: fetch the cinn instruction bound to this operator + auto cached_index = ctx.template Attr(kCachedIndex); + auto ins_index = ctx.template Attr(kInstructionIndex); + const CinnCompiledObject& compiled_object = + CinnCompiler::GetInstance()->GetCompiledObject(cached_index); + const std::vector>& instructions = + compiled_object.runtime_program->GetRunInstructions(); + PADDLE_ENFORCE_LT(ins_index, instructions.size(), + platform::errors::InvalidArgument( + "Index(%ld) > instructions.size(%ld).", ins_index, + instructions.size())); + auto&& instruction = instructions.at(ins_index); + + // step 2: prepare the input and output arguments of the instruction + details::CinnLaunchContext* launch_context = + compiled_object.launch_context.get(); + auto share_argument_buffer_fn = [launch_context, + &ctx](const std::string& var_name) { + cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name); + framework::Variable* var = ctx.scope().GetVar(var_name); + auto* tensor = var->template GetMutable(); + buffer->memory = + reinterpret_cast(tensor->mutable_data(ctx.GetPlace())); + }; + std::vector in_args = ctx.InputNames(kX); + std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn); + std::vector out_args = ctx.OutputNames(kOutputs); + std::for_each(out_args.begin(), out_args.end(), share_argument_buffer_fn); + + // step 3: launch CINN runtime to execute the instruction + // TODO(CtfGo): simplify format of arguments package as a vector in CINN + // and update this usage call + instruction->Run(&launch_context->FinalizeArguments(), false, + details::GetStream(ctx)); + } +}; + +} // namespace paddle::operators diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index fa93bf00f2a..282a8f69e4e 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -24,12 +24,31 @@ CinnLaunchContext::CinnLaunchContext( const std::unordered_map& paddle2cinn_varmap, const std::shared_ptr& cinn_scope) : paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) { + // generate all names of cinn used variables auto var_names = cinn_scope_->var_names(); cinn_variable_names_.reserve(var_names.size()); std::transform( var_names.begin(), var_names.end(), std::inserter(cinn_variable_names_, cinn_variable_names_.end()), [](const auto& name_view) { return std::string(name_view.data()); }); + // build the variable name map of cinn2paddle + for (const auto& x : paddle2cinn_varmap_) { + auto res = cinn2paddle_varmap_.emplace(x.second, x.first); + PADDLE_ENFORCE_EQ( + res.second, true, + platform::errors::InvalidArgument( + "Cinn variable(%s) maps to more than one paddle variable(%s,%s)", + x.second, res.first->second, x.first)); + } + // supplement the relations of the remain variables not appearing in above + // map, + // they are internal variables and here we use the name from cinn compiled. + for (const auto& var_name : cinn_variable_names_) { + if (!cinn2paddle_varmap_.count(var_name)) { + cinn2paddle_varmap_.emplace(var_name, var_name); + paddle2cinn_varmap_.emplace(var_name, var_name); + } + } } void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope, @@ -189,6 +208,20 @@ CinnLaunchContext::FinalizeArguments() const { return name2argument_; } +cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar( + const std::string& paddle_var_name) { + auto res = paddle2cinn_varmap_.find(paddle_var_name); + PADDLE_ENFORCE_NE( + res, paddle2cinn_varmap_.end(), + platform::errors::InvalidArgument( + "Variable(%s) not found in compilation result", paddle_var_name)); + auto it = name2argument_.find(res->second); + PADDLE_ENFORCE_NE(it, name2argument_.end(), + platform::errors::InvalidArgument( + "Argument(%s) not be initialized", res->second)); + return static_cast(it->second); +} + } // namespace details } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.h b/paddle/fluid/operators/cinn/cinn_launch_context.h index 7b71d77d8b8..71ddeb35420 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.h +++ b/paddle/fluid/operators/cinn/cinn_launch_context.h @@ -64,6 +64,8 @@ class CinnLaunchContext { // Finalize all execution arguments and return them const std::map& FinalizeArguments() const; + cinn_buffer_t* GetCinnBufferOfVar(const std::string& paddle_var_name); + private: // Get CinnTensor with CINN variable name CinnTensor GetCinnTensor(const std::string& var_name); @@ -84,19 +86,22 @@ class CinnLaunchContext { std::unique_ptr cached_temp_scope_ = nullptr; // a variable name map from paddle to cinn - const std::unordered_map& paddle2cinn_varmap_; + std::unordered_map paddle2cinn_varmap_; + // a variable name map from cinn to paddle + std::unordered_map cinn2paddle_varmap_; // the variable scope of cinn const std::shared_ptr cinn_scope_; - // all variables used by compiled executable program + // all names of cinn variables used by compiled executable program std::unordered_set cinn_variable_names_; // because a cinn_pod_value_t does not own the cinn_buffer_t object, // an extra stroage is necessary to keep the object and it can - // not be released until runtime program finish execution. + // not be released until the runtime program finish execution. std::vector> hold_buffers_; - // name to execution argument + // this map saves all execution arguments with their cinn names as key, + // and it is passed to the Execute interface of a cinn runtime program. std::map name2argument_; }; diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc index cd17c947228..d918b7216c4 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -13,10 +13,11 @@ // limitations under the License. #include "paddle/fluid/operators/cinn/cinn_launch_op.h" - #include #include - +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/runtime/cinn_runtime.h" +#include "cinn/runtime/flags.h" #include "paddle/fluid/string/string_helper.h" DECLARE_bool(cudnn_deterministic); diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc index ea36a19202e..9dfd53834e9 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc @@ -13,36 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cinn/cinn_launch_op.h" -#include -#include -#include "cinn/runtime/cinn_runtime.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/platform/device/gpu/gpu_types.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" - -#ifdef PADDLE_WITH_CUDA -#include -#endif - -namespace paddle { -namespace operators { -namespace details { - -#ifdef PADDLE_WITH_CUDA -template <> -void* GetStream( - const framework::ExecutionContext& ctx) { - const auto& dev_ctx = - ctx.template device_context(); - return dev_ctx.stream(); -} -#endif - -} // namespace details -} // namespace operators -} // namespace paddle /* see [Why use single type kernel] */ REGISTER_OP_CUDA_KERNEL(cinn_launch, diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h index 23dfa9d84c0..bd9b30f559b 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn/cinn_launch_op.h @@ -18,27 +18,18 @@ #include #include #include -#include "cinn/hlir/framework/graph_compiler.h" -#include "cinn/hlir/framework/scope.h" -#include "cinn/runtime/cinn_runtime.h" -#include "cinn/runtime/flags.h" +#include "cinn/common/target.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h" +#include "paddle/fluid/operators/cinn/cinn_op_helper.h" namespace paddle { namespace operators { -constexpr char kX[] = "X"; -constexpr char kNoNeedBufferX[] = "NoNeedBufferX"; -constexpr char kOutputs[] = "Out"; -constexpr char kCompilationKey[] = "compilation_key"; - using LoDTensor = framework::LoDTensor; -using CinnTensor = ::cinn::hlir::framework::Tensor; -using CinnScope = ::cinn::hlir::framework::Scope; using CinnCompiler = framework::paddle2cinn::CinnCompiler; using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject; @@ -57,17 +48,6 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, // Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS. void SetCinnRuntimeFlags(); -template -void* GetStream(const framework::ExecutionContext& ctx) { - return nullptr; -} - -#ifdef PADDLE_WITH_CUDA -template <> -void* GetStream( - const framework::ExecutionContext& ctx); -#endif - } // namespace details template diff --git a/paddle/fluid/operators/cinn/cinn_op_helper.cc b/paddle/fluid/operators/cinn/cinn_op_helper.cc new file mode 100644 index 00000000000..3fb9c822c77 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_op_helper.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/cinn/cinn_op_helper.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle::operators::details { + +#ifdef PADDLE_WITH_CUDA +template <> +void* GetStream( + const framework::ExecutionContext& ctx) { + const auto& dev_ctx = + ctx.template device_context(); + return dev_ctx.stream(); +} +#endif + +} // namespace paddle::operators::details diff --git a/paddle/fluid/operators/cinn/cinn_op_helper.h b/paddle/fluid/operators/cinn/cinn_op_helper.h new file mode 100644 index 00000000000..e542134b946 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_op_helper.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/operator.h" + +// We define some common names or utility functions +// for operators related to cinn in this file +namespace paddle::operators { + +// input params, output params and attributes +constexpr char kX[] = "X"; +constexpr char kNoNeedBufferX[] = "NoNeedBufferX"; +constexpr char kOutputs[] = "Out"; +constexpr char kCompilationKey[] = "compilation_key"; +constexpr char kCachedIndex[] = "cached_index"; +constexpr char kInstructionIndex[] = "instruction_index"; + +// utility functions +namespace details { + +template +void* GetStream(const framework::ExecutionContext& ctx) { + return nullptr; +} + +#ifdef PADDLE_WITH_CUDA +template <> +void* GetStream( + const framework::ExecutionContext& ctx); +#endif + +} // namespace details +} // namespace paddle::operators -- GitLab