diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 67073350d5a8aa3fdddf270f1b9e1f5be27d0eda..6e57b829ade4edd9d1a3edfee9bf8c28ea4b3cb3 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -351,7 +351,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor graph build_strategy bind_threaded_ssa_graph_executor collective_helper - fast_threaded_ssa_graph_executor variable_helper) + fast_threaded_ssa_graph_executor variable_helper cinn_runner) cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor) if(WITH_PSCORE) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 72f7f0e6011c1bdbf50482c8e35b6c1207f5aa73..ad81b48847af9f3501697a3e71dd44b7110af8ee 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass - fix_op_run_order_pass) + paddle_to_cinn_pass fix_op_run_order_pass) if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) endif() diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 0d55882953db352b906920387d49afeee00f194f..a55b809055f3e799d4eb4903f9a2894da75badb0 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -19,8 +19,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_printer.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" -DECLARE_bool(use_mkldnn); DECLARE_bool(convert_all_blocks); +DECLARE_bool(use_cinn); +DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { @@ -71,6 +72,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Note: This pass is used to check whether the multi_device_graph is right. AppendPass("multi_devices_check_pass"); + // Note: This pass is used to enable cinn. + if (FLAGS_use_cinn) { + AppendPass("paddle_to_cinn_pass"); + } SetCollectiveContext(); } diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 99c691e6cf6f7ae7ca5dd9f42071e7bac2429849..6f5f27400752dd9edf679a1ae249e77ed9fbbe89 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,6 +59,7 @@ cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) pass_library(graph_to_program_pass base) +pass_library(paddle_to_cinn_pass base DEPS cinn_runner) pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(fc_fuse_pass inference) @@ -142,6 +143,7 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) +cc_test(paddle_to_cinn_pass_test SRCS paddle_to_cinn_pass_test.cc DEPS paddle_to_cinn_pass proto_desc) cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass) diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc b/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..fbf2cfb8d41d6a587dedb9b3cae6923e4085fc89 --- /dev/null +++ b/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h" + +#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" + +namespace paddle { +namespace framework { +namespace ir { + +void PaddleToCinnPass::ApplyImpl(ir::Graph* graph) const { + paddle2cinn::CinnRunner::GetInstance()->ReplaceWithCinn(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(paddle_to_cinn_pass, paddle::framework::ir::PaddleToCinnPass); diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass.h b/paddle/fluid/framework/ir/paddle_to_cinn_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f3b9bd21ebf9cab29359ee519e272b2e2c4eee98 --- /dev/null +++ b/paddle/fluid/framework/ir/paddle_to_cinn_pass.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class PaddleToCinnPass : public Pass { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc b/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..49d2ce295f3852429bccc7ab36d2ff0874e6533c --- /dev/null +++ b/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc @@ -0,0 +1,40 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h" + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(PaddleToCinnPassTest, TodoTest) { + ProgramDesc program; + Graph graph(program); + + auto pass = paddle::framework::ir::PassRegistry::Instance().Get( + "paddle_to_cinn_pass"); + + pass->Apply(&graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(paddle_to_cinn_pass); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner.cc b/paddle/fluid/framework/paddle2cinn/cinn_runner.cc index de5af910c99add7f2947d9ea13c119dd60f6f3de..ba90095cae6799b91b5f14a904f4cd960083d524 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_runner.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_runner.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" #include +#include +#include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" @@ -26,6 +28,19 @@ namespace paddle2cinn { using ir::Graph; +std::once_flag CinnRunner::get_instance_once_flag_; +std::shared_ptr CinnRunner::instance_; + +std::shared_ptr CinnRunner::GetInstance() { + std::call_once(get_instance_once_flag_, + [&]() { instance_.reset(new CinnRunner()); }); + return instance_; +} + +void CinnRunner::ReplaceWithCinn(Graph* graph) { + // TODO(zhhsplendid): call CINN Api when it is ready +} + std::map CinnRunner::Run( const Graph& graph, Scope* scope, std::map* feed_targets) { diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner.h b/paddle/fluid/framework/paddle2cinn/cinn_runner.h index 5f63d64545ff75440e68ca88e5892d2e16291d26..23d9565d2f3926de33bab4a3c7fa5ac320763840 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_runner.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_runner.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -36,15 +37,24 @@ namespace paddle2cinn { // cache. class CinnRunner { public: - CinnRunner() {} ~CinnRunner() {} + // Singleton + static std::shared_ptr GetInstance(); + + // Replace Paddle graph with some CINN subgraphs/ops + void ReplaceWithCinn(ir::Graph* graph); + // Feed LoDTensors to tun CINN compiled object and return fetched result std::map Run( const ir::Graph& graph, Scope* scope, std::map* feed_targets); private: + CinnRunner() {} + + static std::once_flag get_instance_once_flag_; + static std::shared_ptr instance_; std::unordered_map, CinnCacheKey::Hash> cache_; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc index 88aca0bd66b375d14a511fa8baa91b05fb00da6f..c02b994c147ca11518e7d0f3a2cd7a2e1e875f94 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gtest/gtest.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" + +#include +#include "gtest/gtest.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -32,8 +34,9 @@ TEST(CinnRunnerTest, TodoTest) { Scope empty_scope; std::map empty_feed; - CinnRunner cinn_runner; - cinn_runner.Run(empty_graph, &empty_scope, &empty_feed); + std::shared_ptr cinn_runner = CinnRunner::GetInstance(); + cinn_runner->ReplaceWithCinn(&empty_graph); + cinn_runner->Run(empty_graph, &empty_scope, &empty_feed); } } // namespace paddle2cinn diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index d19ac0b65f4d1e30de0b60e01d593f2e3a01c448..3b80e9c78677d1c754935614caf6d11cb4467507 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/event.h" @@ -43,6 +44,7 @@ limitations under the License. */ #include "paddle/fluid/platform/cuda_device_guard.h" #endif +DECLARE_bool(use_cinn); DECLARE_double(eager_delete_tensor_gb); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -943,6 +945,40 @@ void ParallelExecutor::RunWithoutFetch( member_->executor_->Run(/*fetch_tensors*/ {}, /*return_merged*/ false); } +FetchResultType ParallelExecutor::RunFromCinn( + const std::unordered_map &feed_tensors, + const std::vector &fetch_names) { + // Feed tensor to scope, now only support 1 scope + // TODO(zhhsplendid): handle multiple scope + size_t scope_id = 0; + std::map cinn_input_tensors; + for (auto &name_tensor_pair : feed_tensors) { + bool is_persistable = member_->IsPersistable(name_tensor_pair.first); + if (!is_persistable) { + member_->SetSkipMemoryReuse(scope_id, name_tensor_pair.first); + } + Scope *feed_scope = is_persistable ? member_->local_scopes_[scope_id] + : member_->local_exec_scopes_[scope_id]; + Variable *feed_var = feed_scope->Var(name_tensor_pair.first); + LoDTensor *trg = feed_var->GetMutable(); + trg->ShareDataWith(name_tensor_pair.second); + trg->set_lod(name_tensor_pair.second.lod()); + + cinn_input_tensors[name_tensor_pair.first] = trg; + } + + // TODO(zhhsplendid): get correct API after CINN API is ready + // now only return empty fetch result; + std::shared_ptr cinn_runner = + paddle2cinn::CinnRunner::GetInstance(); + + cinn_runner->Run(Graph(), member_->local_exec_scopes_[scope_id], + &cinn_input_tensors); + + paddle::framework::FetchResultType fetches = FetchList(fetch_names.size()); + return fetches; +} + void ParallelExecutor::SkipMemoryReuse( size_t scope_idx, const std::vector &skip_vars) { for (auto &var_name : skip_vars) { diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 78774f048963895081384cb53499b49ffc63c37f..f908ce3f013937d8b9050442bd9ffb960387ea32 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -92,6 +93,10 @@ class ParallelExecutor { void RunWithoutFetch(const std::vector &skip_eager_vars); + FetchResultType RunFromCinn( + const std::unordered_map &feed_tensors, + const std::vector &fetch_names); + void ResetOpHandleScopeMapOfGraphs( const std::unordered_map &scope_map); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 7a7666665511fa78c33891e568b744b0cd2b3b19..18636f6f8427854a778506edff2d2299fbae2e87 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -681,6 +681,16 @@ PADDLE_DEFINE_EXPORTED_bool( apply_pass_to_program, false, "It controls whether to apply IR pass to program when using Fleet APIs"); +/** + * CINN related FLAG + * Name: FLAGS_use_cinn + * Since Version: 2.3 + * Value Range: bool, default=false + * Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN + */ +PADDLE_DEFINE_EXPORTED_bool( + use_cinn, false, "It controls whether to run PaddlePaddle using CINN"); + DEFINE_int32(record_pool_max_size, 2000000, "SlotRecordDataset slot record pool max size"); DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num"); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f58c2a5db381c76d4dfea4fa0a55a95c27dbd12e..80350abb4fe2199c8e5a3e49d4409a3189c4e7fe 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3293,6 +3293,18 @@ All parameter, weight, gradient are variables in Paddle. BOOST_GET(paddle::framework::FetchUnmergedList, ret))); } }) + .def("run_from_cinn", + [](ParallelExecutor &self, + const std::unordered_map &feed_tensors, + const std::vector &fetch_names) -> py::object { + paddle::framework::FetchResultType ret; + { + pybind11::gil_scoped_release release; + ret = self.RunFromCinn(feed_tensors, fetch_names); + } + return py::cast( + std::move(BOOST_GET(paddle::framework::FetchList, ret))); + }) .def("device_count", &ParallelExecutor::DeviceCount); BindFleetWrapper(&m); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8c118f31cbe87a379dc2281b1d08f9e43702525a..bea5b29ecafa6523ed16a677bbd8445bae1b3bf1 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -23,7 +23,8 @@ import numpy as np from .wrapped_decorator import signature_safe_contextmanager import six from .data_feeder import convert_dtype -from .framework import Program, default_main_program, Variable, Operator, convert_np_dtype_to_dtype_ +from .framework import Program, default_main_program, Variable, Operator +from .framework import convert_np_dtype_to_dtype_, get_flags from . import core from . import unique_name from . import compiler @@ -1016,7 +1017,16 @@ class Executor(object): check_feed_shape_type(var, feed_tensor, exe.device_count()) feed_tensor_dict[feed_name] = feed_tensor - exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) + #TODO(zhhsplendid): handle other feed data format case for CINN + use_cinn = get_flags("FLAGS_use_cinn")["FLAGS_use_cinn"] + if use_cinn: + fetch_var_names = list(map(_to_name_str, fetch_list)) + fetch_tensors = exe.run_from_cinn( + feed_tensor_dict, fetch_var_names)._move_to_list() + return as_numpy( + fetch_tensors) if return_numpy else fetch_tensors + else: + exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) elif isinstance(feed, list) or isinstance(feed, tuple): res = list() for i, each in enumerate(feed): @@ -1036,6 +1046,8 @@ class Executor(object): check_feed_shape_type(var, tensor) res_dict[feed_name] = tensor res.append(res_dict) + + use_cinn = get_flags("FLAGS_use_cinn")["FLAGS_use_cinn"] exe.feed_tensors_into_local_scopes(res) if hasattr(program._program, 'lr_sheduler'): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b1d838261f45b0987554c3d734fd8a6d63905a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle +import unittest + +paddle.enable_static() + + +class TestParallelExecutorRunCinn(unittest.TestCase): + def test_run_from_cinn(self): + paddle.set_flags({'FLAGS_use_cinn': True}) + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + data = paddle.static.data( + name='X', shape=[None, 1], dtype='float32') + prediction = paddle.static.nn.fc(data, 2) + loss = paddle.mean(prediction) + adam = paddle.optimizer.Adam() + adam.minimize(loss) + + place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(startup_program) + compiled_program = paddle.static.CompiledProgram( + main_program).with_data_parallel(loss_name=loss.name) + + batch_size = 16 + x = np.random.random(size=(batch_size, 1)).astype('float32') + fetch = exe.run(compiled_program, + feed={'X': x}, + fetch_list=[prediction.name], + return_merged=False) + + paddle.set_flags({'FLAGS_use_cinn': False}) + + +if __name__ == '__main__': + unittest.main()