diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 7e08d883625ad372803929df2409208ee3db5aed..c15744fc1650db2048688bb183c9cff30d779a0c 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -34,7 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" -#include "paddle/fluid/operators/cinn_launch_op.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index 1649cd24b341ab3407acc6e0e3403cc4efffee55..586b59a05ecef82b30f5df3c3f2122c683dd5412 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -27,7 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" -#include "paddle/fluid/operators/cinn_launch_op.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index 7d339f53f0fa4048ad243e93c924d72c621edd64..db20e423c4a40f51184f921ee3e8b9be0ad276ac 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -34,7 +34,7 @@ #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/operators/cinn_launch_op.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f0621af9bbda5ce9cdc7cce0a7be40a091e95cec..6b8567589872ffd7ce8788b5af17dc5ead42043e 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -52,6 +52,10 @@ if (WITH_LITE) add_subdirectory(lite) endif() +if(WITH_CINN) + add_subdirectory(cinn) +endif() + SET(OP_HEADER_DEPS xxhash executor) if (WITH_GPU) @@ -82,7 +86,7 @@ endif() set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op - recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op cinn_launch_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) + recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) op_library(save_combine_op DEPS string_array) @@ -167,14 +171,6 @@ if (WITH_ASCEND_CL) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner) endif() -if (WITH_CINN) - op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS transform_desc cinn_compiler cinn ${OP_HEADER_DEPS}) - if (WITH_TESTING) - cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op) - set_tests_properties(cinn_launch_op_test PROPERTIES ENVIRONMENT OMP_NUM_THREADS=1) - endif() -endif() - # FIXME(typhoonzero): operator deps may not needed. # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) # op_library(array_to_lod_tensor_op DEPS lod_rank_table_op) @@ -205,7 +201,7 @@ elseif(WITH_ROCM) else() cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) endif() -cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context share_buffer_op) +cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context share_buffer_op) cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) diff --git a/paddle/fluid/operators/cinn/CMakeLists.txt b/paddle/fluid/operators/cinn/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1a7c6daf3f380dd6e5251bd451175cd9ea32c0b2 --- /dev/null +++ b/paddle/fluid/operators/cinn/CMakeLists.txt @@ -0,0 +1,13 @@ +include(operators) +register_operators(EXCLUDES cinn_launch_op) + +cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS cinn) +op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS cinn cinn_compiler cinn_launch_context) + +if (WITH_TESTING) + cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS scope lod_tensor cinn_launch_context) + set_tests_properties(cinn_launch_context_test PROPERTIES LABELS "RUN_TYPE=CINN") + + cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op) + set_tests_properties(cinn_launch_op_test PROPERTIES ENVIRONMENT "OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda") +endif() diff --git a/paddle/fluid/operators/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc similarity index 54% rename from paddle/fluid/operators/cinn_launch_op.cc rename to paddle/fluid/operators/cinn/cinn_launch_context.cc index f0ad5b3c3bf996c13aad6986920fd45e05287184..90a4ca73399cf55f29398591e9f65aab70dffbda 100644 --- a/paddle/fluid/operators/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -12,76 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/cinn_launch_op.h" - +#include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include #include -#include "paddle/fluid/string/string_helper.h" - -DECLARE_bool(cudnn_deterministic); - namespace paddle { namespace operators { - namespace details { -const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place) { - if (platform::is_cpu_place(place)) { - return ::cinn::common::DefaultHostTarget(); - } else if (platform::is_gpu_place(place)) { - return ::cinn::common::DefaultNVGPUTarget(); - } - - PADDLE_THROW(platform::errors::InvalidArgument( - "CINN is not supported on current place:%s", place)); - return ::cinn::common::UnkTarget(); -} - -void DebugCinnCompiledResult(const CinnCompiledObject& result) { - if (!VLOG_IS_ON(4)) { - return; - } - const auto& cinn_runtime_program = result.runtime_program; - const auto& cinn_scope = *(result.scope); - const auto& paddle2cinn_varmap = result.paddle2cinn_varmap; - - VLOG(4) << "Compiled runtime_program instrunction size:[" - << cinn_runtime_program->size() << "]"; - - std::vector infos; - auto cinn_var_names = cinn_scope.var_names(); - infos.reserve(cinn_var_names.size()); - std::transform(cinn_var_names.begin(), cinn_var_names.end(), - std::back_inserter(infos), - [](const auto& name_view) { return name_view.data(); }); - VLOG(4) << "Compiled scope variable names:[" - << string::join_strings(infos, ',') << "]"; - - infos.clear(); - infos.reserve(paddle2cinn_varmap.size()); - std::transform(paddle2cinn_varmap.begin(), paddle2cinn_varmap.end(), - std::back_inserter(infos), [](const auto& paddle2cinn) { - return paddle2cinn.first + "->" + paddle2cinn.second; - }); - VLOG(4) << "Compiled paddle2cinn_varmap:[" << string::join_strings(infos, ',') - << "]"; -} - -void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, - const CinnLaunchContext& context, void* stream) { - compiled_obj.runtime_program->Execute(&context.FinalizeArguments(), stream); -} - -void SetCinnRuntimeFlags() { - VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to " - << FLAGS_cudnn_deterministic; - ::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic); -} - -CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj) - : paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap), - cinn_scope_(compiled_obj.scope) { +CinnLaunchContext::CinnLaunchContext( + const std::unordered_map& paddle2cinn_varmap, + const std::shared_ptr& cinn_scope) + : paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) { auto var_names = cinn_scope_->var_names(); cinn_variable_names_.reserve(var_names.size()); std::transform( @@ -221,90 +163,5 @@ CinnLaunchContext::FinalizeArguments() const { } } // namespace details - -class CinnLaunchOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp"); - OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, - "CinnLaunchOp"); - } - - protected: - /* [Why use single type kernel]: - * - * This op is similar to a control flow op, it doses not need - * a op kernel, but in order to make it execute under dynamic - * graph mode, implement it with op kernel. - * - * So whether the kernel data type is int, float or other type, - * which has no effect on its execution logic, so directly - * specified a data type here. - * - * Of course, the data type here is also not important. - */ - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.GetPlace()); - } -}; - -class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput(kX, - "(vector)" - "which are the input of graph inside the CinnLaunchOp.") - .AsDuplicable(); - AddOutput(kOutputs, - "(vector)" - "which are the output of graph inside the CinnLaunchOp.") - .AsDuplicable(); - AddAttr( - kCompilationKey, - "(string)" - "a hash key used to get the graph object or its computation result."); - AddComment(R"DOC( -CinnLaunch Operator. - -This operator is used to launch CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) -to compile a graph and execute the compiled object. - -Both input and output of this operator are a set of variables -which are input and output of the graph respectively that will be -compiled and executed in this operator. -In addition, there is an attribute named 'compilation_key' should be -set necessarily to get corresponding ir::Graph object of the graph -or its computation result. - -It accomplishes the computation of graph following several steps: - 1. Fetch ir::Graph object from CinnCompiler using kCompilationKey - 2. Compile the graph to a compiled object, and insert it to the - global cache so that we can directly query it from this cache next time - when shape of input variables are not changed at all. - 3. Create and instantiate all variables used to execute compiled runtime program - if necessary according to the info(type,shape) included in the return scope. - 4. Pack each tensor buffer of all above variables as execution arguments. - 5. Launch execution of the runtime program with above arguments, then - the result would be output by writing value on underlying buffer address. - -)DOC"); - } -}; - } // namespace operators } // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR( - cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -/* see [Why use single type kernel] */ -REGISTER_OP_CPU_KERNEL( - cinn_launch, - ops::CinnLaunchOpKernel); diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.h b/paddle/fluid/operators/cinn/cinn_launch_context.h new file mode 100644 index 0000000000000000000000000000000000000000..c990255d68253d048d8c5b3806dbabf80efe746d --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_launch_context.h @@ -0,0 +1,104 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "cinn/hlir/framework/scope.h" +#include "cinn/hlir/framework/tensor.h" +#include "cinn/runtime/cinn_runtime.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { +namespace details { + +using LoDTensor = framework::LoDTensor; +using CinnTensor = ::cinn::hlir::framework::Tensor; +using CinnScope = ::cinn::hlir::framework::Scope; + +class CinnLaunchContext { + public: + explicit CinnLaunchContext( + const std::unordered_map& paddle2cinn_varmap, + const std::shared_ptr& cinn_scope); + + // Return whether a Paddle variable used on compiled kernels + bool IsVariableUsed(const std::string& var_name); + + // Assign tensor buffer to input or output variables + void AssignExternalVariable(const std::string& var_name, + const platform::Place& place, LoDTensor* tensor); + + // Assign tensor buffer to internal variables + void AssignInternalVariable(const std::string& var_name, + const platform::Place& place, LoDTensor* tensor); + + // Extract internal variable names from CinnScope + // by excluding used input and output variables + std::unordered_set GetInternalVariableNames(); + + // Finalize all execution arguments and return them + const std::map& FinalizeArguments() const; + + std::vector> HandoverBuffers() { + return std::move(hold_buffers_); + } + + private: + // Get CinnTensor with CINN variable name + CinnTensor GetCinnTensor(const std::string& var_name); + + // Check whether tensors from Paddle and CINN of the same variable + // are equivalent in type and dimension + void CheckTensorEquivalent(const std::string& var_name, + const LoDTensor& paddle_tensor, + const CinnTensor& cinn_tensor); + + // Share the buffer of a Paddle tensor to CINN by delivering memory address + // to a cinn_buffer_t object + std::unique_ptr ShareTensorWithCinnBuffer( + const platform::Place& place, bool free_mem_callback, LoDTensor* tensor); + + // Set an argument with (cinn name)->(paddle tensor) pair + void SetArgument(const std::string& cinn_name, const platform::Place& place, + bool free_mem_callback, LoDTensor* paddle_tensor); + + private: + // a variable name map from paddle to cinn + const std::unordered_map& paddle2cinn_varmap_; + // the variable scope of cinn + const std::shared_ptr cinn_scope_; + + // all variables used by compiled executable program + std::unordered_set cinn_variable_names_; + + // because a cinn_pod_value_t does not own the cinn_buffer_t object, + // an extra stroage is necessary to keep the object and it can + // not be released until runtime program finish execution. + std::vector> hold_buffers_; + + // name to execution argument + std::map name2argument_; +}; + +} // namespace details +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cinn/cinn_launch_context_test.cc b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d922e8355b44c5e15214160f15dec22cf70de719 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc @@ -0,0 +1,134 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cinn/cinn_launch_context.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace operators { +namespace details { + +using CinnShape = ::cinn::hlir::framework::Shape; + +std::unique_ptr CreateDefaultLaunchContext() { + static std::once_flag initialized; + static std::unordered_map paddle2cinn_varmap; + static std::shared_ptr cinn_scope; + std::call_once(initialized, [&paddle2cinn_varmap, &cinn_scope]() { + auto& scope = cinn_scope; + scope = std::make_shared(); + + scope->Var("cinn_var1"); + scope->GetTensor("cinn_var1")->Resize(CinnShape({3, 4})); + scope->Var("cinn_var2"); + scope->GetTensor("cinn_var2")->Resize(CinnShape({6, 7, 8})); + scope->Var("cinn_var3"); + scope->GetTensor("cinn_var3")->Resize(CinnShape({10, 16})); + + paddle2cinn_varmap = { + {"var1", "cinn_var1"}, {"var3", "cinn_var3"}, {"var4", "cinn_var4"}}; + }); + + return std::make_unique(paddle2cinn_varmap, cinn_scope); +} + +TEST(CinnLaunchContextTest, TestIsVariableUsed) { + auto launch_context = CreateDefaultLaunchContext(); + + ASSERT_EQ(launch_context->IsVariableUsed("var1"), true); + ASSERT_EQ(launch_context->IsVariableUsed("var4"), false); +} + +TEST(CinnLaunchContextTest, TestGetInternalVariableNames) { + auto launch_context = CreateDefaultLaunchContext(); + auto internal_variable_names = launch_context->GetInternalVariableNames(); + ASSERT_EQ(internal_variable_names.size(), 3); + EXPECT_NE(internal_variable_names.find("cinn_var2"), + internal_variable_names.end()); +} + +TEST(CinnLaunchContextTest, TestCheckTensorEquivalent) { + auto launch_context = CreateDefaultLaunchContext(); + platform::CPUPlace place; + framework::Scope scope; + auto* tensor1 = scope.Var("var1")->GetMutable(); + + // CheckTensorEquivalent: tensor dimension not equivalent + tensor1->mutable_data(framework::make_ddim({3, 5}), place); + ASSERT_THROW(launch_context->AssignExternalVariable("var1", place, tensor1), + paddle::platform::EnforceNotMet); +} + +TEST(CinnLaunchContextTest, TestAssignVariablePreCondition) { + auto launch_context = CreateDefaultLaunchContext(); + platform::CPUPlace place; + framework::Scope scope; + auto* tensor4 = scope.Var("var4")->GetMutable(); + + // not used + ASSERT_THROW(launch_context->AssignExternalVariable("var4", place, tensor4), + paddle::platform::EnforceNotMet); + // not found + ASSERT_THROW( + launch_context->AssignExternalVariable("cinn_var4", place, tensor4), + paddle::platform::EnforceNotMet); +} + +TEST(CinnLaunchContextTest, TestSetArgument) { + auto launch_context = CreateDefaultLaunchContext(); + + platform::CPUPlace place; + framework::Scope scope; + auto* tensor1 = scope.Var("var1")->GetMutable(); + float* data1 = + tensor1->mutable_data(framework::make_ddim({3, 4}), place); + data1[0] = 9.99f; + data1[10] = 19.99f; + + // assign external variable + ASSERT_NO_THROW( + launch_context->AssignExternalVariable("var1", place, tensor1)); + auto* tensor2 = scope.Var("var2")->GetMutable(); + tensor2->mutable_data(framework::make_ddim({6, 7, 8}), place); + ASSERT_NO_THROW( + launch_context->AssignInternalVariable("cinn_var2", place, tensor2)); + // FinalizeArguments not missed check + ASSERT_THROW(launch_context->FinalizeArguments(), + paddle::platform::EnforceNotMet); + auto* tensor3 = scope.Var("var3")->GetMutable(); + tensor3->mutable_data(framework::make_ddim({10, 16}), place); + ASSERT_NO_THROW( + launch_context->AssignExternalVariable("var3", place, tensor3)); + + auto name2argument = launch_context->FinalizeArguments(); + ASSERT_EQ(name2argument.size(), 3); + ASSERT_EQ(name2argument.count("cinn_var1"), 1); + // check ShareTensorWithCinnBuffer + auto* cinn_buffer = + static_cast(name2argument.at("cinn_var1")); + + ASSERT_EQ(cinn_buffer->memory, nullptr); + cinn_buffer->external_malloc->operator()(nullptr, cinn_buffer); + ASSERT_NE(cinn_buffer->memory, nullptr); + ASSERT_EQ(cinn_buffer->num_elements(), 12); + auto* shadow_data = reinterpret_cast(cinn_buffer->memory); + EXPECT_FLOAT_EQ(shadow_data[0], 9.99f); + EXPECT_FLOAT_EQ(shadow_data[10], 19.99f); +} + +} // namespace details +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b86d9b2b3d41540d70122362215bd6a77ef0184 --- /dev/null +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" + +#include +#include + +#include "paddle/fluid/string/string_helper.h" + +DECLARE_bool(cudnn_deterministic); + +namespace paddle { +namespace operators { + +namespace details { + +const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place) { + if (platform::is_cpu_place(place)) { + return ::cinn::common::DefaultHostTarget(); + } else if (platform::is_gpu_place(place)) { + return ::cinn::common::DefaultNVGPUTarget(); + } + + PADDLE_THROW(platform::errors::InvalidArgument( + "CINN is not supported on current place:%s", place)); + return ::cinn::common::UnkTarget(); +} + +void DebugCinnCompiledResult(const CinnCompiledObject& result) { + if (!VLOG_IS_ON(4)) { + return; + } + const auto& cinn_runtime_program = result.runtime_program; + const auto& cinn_scope = *(result.scope); + const auto& paddle2cinn_varmap = result.paddle2cinn_varmap; + + VLOG(4) << "Compiled runtime_program instrunction size:[" + << cinn_runtime_program->size() << "]"; + + std::vector infos; + auto cinn_var_names = cinn_scope.var_names(); + infos.reserve(cinn_var_names.size()); + std::transform(cinn_var_names.begin(), cinn_var_names.end(), + std::back_inserter(infos), + [](const auto& name_view) { return name_view.data(); }); + VLOG(4) << "Compiled scope variable names:[" + << string::join_strings(infos, ',') << "]"; + + infos.clear(); + infos.reserve(paddle2cinn_varmap.size()); + std::transform(paddle2cinn_varmap.begin(), paddle2cinn_varmap.end(), + std::back_inserter(infos), [](const auto& paddle2cinn) { + return paddle2cinn.first + "->" + paddle2cinn.second; + }); + VLOG(4) << "Compiled paddle2cinn_varmap:[" << string::join_strings(infos, ',') + << "]"; +} + +void LaunchCinnExecution(const CinnCompiledObject& compiled_obj, + const CinnLaunchContext& context, void* stream) { + compiled_obj.runtime_program->Execute(&context.FinalizeArguments(), stream); +} + +void SetCinnRuntimeFlags() { + VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to " + << FLAGS_cudnn_deterministic; + ::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic); +} + +} // namespace details + +class CinnLaunchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp"); + OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, + "CinnLaunchOp"); + } + + protected: + /* [Why use single type kernel]: + * + * This op is similar to a control flow op, it doses not need + * a op kernel, but in order to make it execute under dynamic + * graph mode, implement it with op kernel. + * + * So whether the kernel data type is int, float or other type, + * which has no effect on its execution logic, so directly + * specified a data type here. + * + * Of course, the data type here is also not important. + */ + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } +}; + +class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput(kX, + "(vector)" + "which are the input of graph inside the CinnLaunchOp.") + .AsDuplicable(); + AddOutput(kOutputs, + "(vector)" + "which are the output of graph inside the CinnLaunchOp.") + .AsDuplicable(); + AddAttr( + kCompilationKey, + "(string)" + "a hash key used to get the graph object or its computation result."); + AddComment(R"DOC( +CinnLaunch Operator. + +This operator is used to launch CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) +to compile a graph and execute the compiled object. + +Both input and output of this operator are a set of variables +which are input and output of the graph respectively that will be +compiled and executed in this operator. +In addition, there is an attribute named 'compilation_key' should be +set necessarily to get corresponding ir::Graph object of the graph +or its computation result. + +It accomplishes the computation of graph following several steps: + 1. Fetch ir::Graph object from CinnCompiler using kCompilationKey + 2. Compile the graph to a compiled object, and insert it to the + global cache so that we can directly query it from this cache next time + when shape of input variables are not changed at all. + 3. Create and instantiate all variables used to execute compiled runtime program + if necessary according to the info(type,shape) included in the return scope. + 4. Pack each tensor buffer of all above variables as execution arguments. + 5. Launch execution of the runtime program with above arguments, then + the result would be output by writing value on underlying buffer address. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +/* see [Why use single type kernel] */ +REGISTER_OP_CPU_KERNEL( + cinn_launch, + ops::CinnLaunchOpKernel); diff --git a/paddle/fluid/operators/cinn_launch_op.cu.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc similarity index 97% rename from paddle/fluid/operators/cinn_launch_op.cu.cc rename to paddle/fluid/operators/cinn/cinn_launch_op.cu.cc index fae2d6ddb487d9237088ab7072a20ce429ed9d6d..813e7b1152f87eeed18380aaa98cd507d28121b5 100644 --- a/paddle/fluid/operators/cinn_launch_op.cu.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/cinn_launch_op.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include #include #include "cinn/runtime/cinn_runtime.h" diff --git a/paddle/fluid/operators/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h similarity index 74% rename from paddle/fluid/operators/cinn_launch_op.h rename to paddle/fluid/operators/cinn/cinn_launch_op.h index 2b1bf89197dffb7812beb2406e7fd9278987e8a5..3a272916332beace3b163414d717eb50c79d3a85 100644 --- a/paddle/fluid/operators/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn/cinn_launch_op.h @@ -26,6 +26,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn/cinn_launch_context.h" namespace paddle { namespace operators { @@ -42,69 +43,6 @@ using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject; namespace details { -class CinnLaunchContext { - public: - explicit CinnLaunchContext(const CinnCompiledObject& compiled_obj); - - // Return whether a Paddle variable used on compiled kernels - bool IsVariableUsed(const std::string& var_name); - - // Assign tensor buffer to input or output variables - void AssignExternalVariable(const std::string& var_name, - const platform::Place& place, LoDTensor* tensor); - - // Assign tensor buffer to internal variables - void AssignInternalVariable(const std::string& var_name, - const platform::Place& place, LoDTensor* tensor); - - // Extract internal variable names from CinnScope - // by excluding used input and output variables - std::unordered_set GetInternalVariableNames(); - - // Finalize all execution arguments and return them - const std::map& FinalizeArguments() const; - - std::vector> HandoverBuffers() { - return std::move(hold_buffers_); - } - - private: - // Get CinnTensor with CINN variable name - CinnTensor GetCinnTensor(const std::string& var_name); - - // Check whether tensors from Paddle and CINN of the same variable - // are equivalent in type and dimension - void CheckTensorEquivalent(const std::string& var_name, - const LoDTensor& paddle_tensor, - const CinnTensor& cinn_tensor); - - // Share the buffer of a Paddle tensor to CINN by delivering memory address - // to a cinn_buffer_t object - std::unique_ptr ShareTensorWithCinnBuffer( - const platform::Place& place, bool free_mem_callback, LoDTensor* tensor); - - // Set an argument with (cinn name)->(paddle tensor) pair - void SetArgument(const std::string& cinn_name, const platform::Place& place, - bool free_mem_callback, LoDTensor* paddle_tensor); - - private: - // a variable name map from paddle to cinn - const std::unordered_map& paddle2cinn_varmap_; - // the variable scope of cinn - const std::shared_ptr cinn_scope_; - - // all variables used by compiled executable program - std::unordered_set cinn_variable_names_; - - // because a cinn_pod_value_t does not own the cinn_buffer_t object, - // an extra stroage is necessary to keep the object and it can - // not be released until runtime program finish execution. - std::vector> hold_buffers_; - - // name to execution argument - std::map name2argument_; -}; - // Tranform Paddle place to CINN target const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place); @@ -178,8 +116,8 @@ class CinnLaunchOpKernel : public framework::OpKernel { compilation_key, inputs_name2tensor, target, stream); details::DebugCinnCompiledResult(cinn_compiled_object); - auto launch_context = - std::make_unique(cinn_compiled_object); + auto launch_context = std::make_unique( + cinn_compiled_object.paddle2cinn_varmap, cinn_compiled_object.scope); // Step 3. Prepare arguments needed for the compiled executable program. VLOG(4) << "CinnLaunchOp prepare arguments"; diff --git a/paddle/fluid/operators/cinn_launch_op_test.cc b/paddle/fluid/operators/cinn/cinn_launch_op_test.cc similarity index 59% rename from paddle/fluid/operators/cinn_launch_op_test.cc rename to paddle/fluid/operators/cinn/cinn_launch_op_test.cc index 5e0b87d06afeaffb60e73543c9325afb30f5cc41..02373c38184fca3722312dc23e92ac51be86d1ea 100644 --- a/paddle/fluid/operators/cinn_launch_op_test.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/cinn_launch_op.h" +#include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include #include #include @@ -134,16 +134,10 @@ TEST(CinnLaunchOpTest, TestElementwiseAddPass) { elementwise_add_op->Run(scope, place); LoDTensor test_out, expected_out; - if (platform::is_cpu_place(place)) { - test_out.ShareDataWith(scope.Var(test_out_name)->Get()); - expected_out.ShareDataWith( - scope.Var(expected_out_name)->Get()); - } else { - TensorCopySync(scope.Var(test_out_name)->Get(), - platform::CPUPlace(), &test_out); - TensorCopySync(scope.Var(expected_out_name)->Get(), - platform::CPUPlace(), &expected_out); - } + TensorCopySync(scope.Var(test_out_name)->Get(), + platform::CPUPlace(), &test_out); + TensorCopySync(scope.Var(expected_out_name)->Get(), + platform::CPUPlace(), &expected_out); ASSERT_TRUE(test_out.IsInitialized()); ASSERT_TRUE(expected_out.IsInitialized()); @@ -189,116 +183,6 @@ TEST(CinnLaunchOpHelperTest, TestPlaceToCinnTarget) { paddle::platform::EnforceNotMet); } -const CinnCompiledObject& GetDefaultCompiledObj() { - static std::once_flag initialized; - static CinnCompiledObject compiled_object; - std::call_once(initialized, [&compiled_object]() { - auto& scope = compiled_object.scope; - scope = std::make_shared(); - - scope->Var("cinn_var1"); - scope->GetTensor("cinn_var1")->Resize(CinnShape({3, 4})); - scope->Var("cinn_var2"); - scope->GetTensor("cinn_var2")->Resize(CinnShape({6, 7, 8})); - scope->Var("cinn_var3"); - scope->GetTensor("cinn_var3")->Resize(CinnShape({10, 16})); - - auto& varmap = compiled_object.paddle2cinn_varmap; - varmap = { - {"var1", "cinn_var1"}, {"var3", "cinn_var3"}, {"var4", "cinn_var4"}}; - }); - return compiled_object; -} - -TEST(CinnLaunchContextTest, TestIsVariableUsed) { - auto launch_context = - std::make_unique(GetDefaultCompiledObj()); - - ASSERT_EQ(launch_context->IsVariableUsed("var1"), true); - ASSERT_EQ(launch_context->IsVariableUsed("var4"), false); -} - -TEST(CinnLaunchContextTest, TestGetInternalVariableNames) { - auto launch_context = - std::make_unique(GetDefaultCompiledObj()); - auto internal_variable_names = launch_context->GetInternalVariableNames(); - ASSERT_EQ(internal_variable_names.size(), 3); - EXPECT_NE(internal_variable_names.find("cinn_var2"), - internal_variable_names.end()); -} - -TEST(CinnLaunchContextTest, TestCheckTensorEquivalent) { - auto launch_context = - std::make_unique(GetDefaultCompiledObj()); - platform::CPUPlace place; - framework::Scope scope; - auto* tensor1 = scope.Var("var1")->GetMutable(); - - // CheckTensorEquivalent: tensor dimension not equivalent - tensor1->mutable_data(framework::make_ddim({3, 5}), place); - ASSERT_THROW(launch_context->AssignExternalVariable("var1", place, tensor1), - paddle::platform::EnforceNotMet); -} - -TEST(CinnLaunchContextTest, TestAssignVariablePreCondition) { - auto launch_context = - std::make_unique(GetDefaultCompiledObj()); - platform::CPUPlace place; - framework::Scope scope; - auto* tensor4 = scope.Var("var4")->GetMutable(); - - // not used - ASSERT_THROW(launch_context->AssignExternalVariable("var4", place, tensor4), - paddle::platform::EnforceNotMet); - // not found - ASSERT_THROW( - launch_context->AssignExternalVariable("cinn_var4", place, tensor4), - paddle::platform::EnforceNotMet); -} - -TEST(CinnLaunchContextTest, TestSetArgument) { - auto launch_context = - std::make_unique(GetDefaultCompiledObj()); - - platform::CPUPlace place; - framework::Scope scope; - auto* tensor1 = scope.Var("var1")->GetMutable(); - float* data1 = - tensor1->mutable_data(framework::make_ddim({3, 4}), place); - data1[0] = 9.99f; - data1[10] = 19.99f; - - // assign external variable - ASSERT_NO_THROW( - launch_context->AssignExternalVariable("var1", place, tensor1)); - auto* tensor2 = scope.Var("var2")->GetMutable(); - tensor2->mutable_data(framework::make_ddim({6, 7, 8}), place); - ASSERT_NO_THROW( - launch_context->AssignInternalVariable("cinn_var2", place, tensor2)); - // FinalizeArguments not missed check - ASSERT_THROW(launch_context->FinalizeArguments(), - paddle::platform::EnforceNotMet); - auto* tensor3 = scope.Var("var3")->GetMutable(); - tensor3->mutable_data(framework::make_ddim({10, 16}), place); - ASSERT_NO_THROW( - launch_context->AssignExternalVariable("var3", place, tensor3)); - - auto name2argument = launch_context->FinalizeArguments(); - ASSERT_EQ(name2argument.size(), 3); - ASSERT_EQ(name2argument.count("cinn_var1"), 1); - // check ShareTensorWithCinnBuffer - auto* cinn_buffer = - static_cast(name2argument.at("cinn_var1")); - - ASSERT_EQ(cinn_buffer->memory, nullptr); - cinn_buffer->external_malloc->operator()(nullptr, cinn_buffer); - ASSERT_NE(cinn_buffer->memory, nullptr); - ASSERT_EQ(cinn_buffer->num_elements(), 12); - auto* shadow_data = reinterpret_cast(cinn_buffer->memory); - EXPECT_FLOAT_EQ(shadow_data[0], 9.99f); - EXPECT_FLOAT_EQ(shadow_data[10], 19.99f); -} - } // namespace details } // namespace operators } // namespace paddle