From 20e23e1b811c0c28437e6ed1a5ab98367d2ee91f Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 25 Jan 2022 15:55:51 +0800 Subject: [PATCH] [fleet_executor] Dist model run method Implementation (#39194) --- .../distributed/fleet_executor/dist_model.cc | 293 +++++++++++++++++- .../distributed/fleet_executor/dist_model.h | 15 +- .../dist_model_tensor_wrapper.h | 2 +- paddle/fluid/pybind/bind_fleet_executor.cc | 7 +- .../fluid/tests/unittests/CMakeLists.txt | 3 +- .../test_fleet_exe_dist_model_run.py | 86 +++++ ...py => test_fleet_exe_dist_model_tensor.py} | 2 +- 7 files changed, 391 insertions(+), 17 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_run.py rename python/paddle/fluid/tests/unittests/{test_dist_model_tensor.py => test_fleet_exe_dist_model_tensor.py} (97%) diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.cc b/paddle/fluid/distributed/fleet_executor/dist_model.cc index 6454a349505..4b848330237 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.cc +++ b/paddle/fluid/distributed/fleet_executor/dist_model.cc @@ -13,11 +13,13 @@ // limitations under the License. #include +#include // NOLINT #include "paddle/fluid/distributed/fleet_executor/dist_model.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_desc.h" @@ -37,10 +39,110 @@ bool IsPersistable(const framework::VarDesc *var) { } return false; } + +bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, + framework::LoDTensor *input_tensor, + const platform::Place &place) { + VLOG(3) << "Loading data from DistModelTensor for " << input_data.name; + framework::DDim dims = framework::make_ddim(input_data.shape); + void *input_tensor_ptr; + if (input_data.dtype == DistModelDataType::INT64) { + input_tensor_ptr = input_tensor->mutable_data(dims, place); + } else if (input_data.dtype == DistModelDataType::FLOAT32) { + input_tensor_ptr = input_tensor->mutable_data(dims, place); + } else if (input_data.dtype == DistModelDataType::INT32) { + input_tensor_ptr = input_tensor->mutable_data(dims, place); + } else { + // Q(fleet exe dev): for input/output, should we support fp16 + LOG(ERROR) << "unsupported feed type " << input_data.dtype; + return false; + } + + PADDLE_ENFORCE_NOT_NULL( + input_tensor_ptr, + paddle::platform::errors::Fatal( + "LoDTensor creation failed. DistModel loaded data failed.")); + PADDLE_ENFORCE_NOT_NULL(input_data.data.data(), + paddle::platform::errors::InvalidArgument( + "DistModelTensor contains no data.")); + + if (platform::is_cpu_place(place)) { + VLOG(3) << "Loading data for CPU."; + std::memcpy(static_cast(input_tensor_ptr), input_data.data.data(), + input_data.data.length()); + } else if (platform::is_gpu_place(place)) { + VLOG(3) << "Loading data for GPU."; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *dev_ctx = + dynamic_cast(pool.Get(place)); + auto gpu_place = place; + memory::Copy(gpu_place, static_cast(input_tensor_ptr), + platform::CPUPlace(), input_data.data.data(), + input_data.data.length(), dev_ctx->stream()); +#else + PADDLE_THROW(paddle::platform::errors::Fatal( + "Paddle wasn't compiled with CUDA, but place is GPU.")); +#endif + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "DistModel only supports CPU and GPU.")); + } + + framework::LoD dst_lod; + for (auto &src_lod : input_data.lod) { + dst_lod.emplace_back(src_lod); + } + input_tensor->set_lod(dst_lod); + return true; +} + +std::string DistModelDTypeToString(DistModelDataType dtype) { + switch (dtype) { + case DistModelDataType::FLOAT32: + return "float32"; + case DistModelDataType::FLOAT16: + return "float16"; + case DistModelDataType::INT64: + return "int64"; + case DistModelDataType::INT32: + return "int32"; + case DistModelDataType::INT8: + return "int8"; + } + return "NOT SUPPORT DTYPE"; +} + +bool IsPPFirstStage(const DistModelConfig &config) { + return config.local_rank - config.mp_degree < 0; +} + +bool IsPPLastStage(const DistModelConfig &config) { + return config.local_rank + config.mp_degree >= config.nranks; +} + +class DistModelTimer { + public: + void tic() { tic_time = std::chrono::high_resolution_clock::now(); } + double toc() { + std::chrono::high_resolution_clock::time_point toc_time = + std::chrono::high_resolution_clock::now(); + std::chrono::duration time_elapse = + std::chrono::duration_cast>(toc_time - + tic_time); + double time_elapse_in_ms = + static_cast(time_elapse.count()) * 1000.0; + return time_elapse_in_ms; + } + + private: + std::chrono::high_resolution_clock::time_point tic_time; +}; + } // namespace bool DistModel::Init() { - /* TODO(fleet exe dev): implement this funct */ + carrier_id_ = "inference"; bool init_method = (!config_.model_dir.empty() || config_.program_desc); PADDLE_ENFORCE_EQ(init_method, true, platform::errors::InvalidArgument( @@ -127,10 +229,9 @@ bool DistModel::CommInit() { InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints, comm_init_block, config_.mp_ring_id); } - if (config_.pp_degree) { - // NOTE: the last pp stage doesn't need init pp comm + if (config_.pp_degree > 1) { VLOG(3) << "Init comm group for pp."; - if (config_.local_rank - config_.mp_degree >= 0) { + if (!IsPPFirstStage(config_)) { PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true, platform::errors::InvalidArgument( "pp upstream ring id must be provided for " @@ -143,7 +244,7 @@ bool DistModel::CommInit() { comm_init_block, config_.pp_upstream_ring_id); } - if (config_.local_rank + config_.mp_degree < config_.nranks) { + if (!IsPPLastStage(config_)) { PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true, platform::errors::InvalidArgument( "pp downstream ring id must be provided for " @@ -326,7 +427,7 @@ bool DistModel::PrepareFleetExe() { id_to_rank.insert({i, i}); } fleet_exe.reset(new FleetExecutor(executor_desc_)); - fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1, + fleet_exe->Init(carrier_id_, *(program_.get()), scope_.get(), place_, 1, {task_node_.get()}, id_to_rank); return true; } @@ -340,8 +441,27 @@ bool DistModel::PrepareFeedAndFetch() { feeds_.resize(idx + 1); } feeds_[idx] = op; - feed_names_[op->Output("Out")[0]] = idx; - idx_to_feeds_[idx] = op->Output("Out")[0]; + std::string var_name = op->Output("Out")[0]; + feed_names_[var_name] = idx; + idx_to_feeds_[idx] = var_name; + framework::VarDesc *real_var = program_->Block(0).FindVar(var_name); + if (!real_var) { + LOG(ERROR) + << "The output of feed ops [" << var_name + << "] cannot be found in the program. Check the inference program."; + return false; + } + if (real_var->GetDataType() == framework::proto::VarType::FP32) { + feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT32}); + } else if (real_var->GetDataType() == framework::proto::VarType::INT32) { + feeds_to_dtype_.insert({var_name, DistModelDataType::INT32}); + } else if (real_var->GetDataType() == framework::proto::VarType::INT64) { + feeds_to_dtype_.insert({var_name, DistModelDataType::INT64}); + } else { + LOG(ERROR) << "Don't support feed var dtype for: " + << real_var->GetDataType(); + return false; + } } else if (op->Type() == "fetch") { VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0]; int idx = BOOST_GET_CONST(int, op->GetAttr("col")); @@ -349,15 +469,166 @@ bool DistModel::PrepareFeedAndFetch() { fetches_.resize(idx + 1); } fetches_[idx] = op; - id_to_fetches_[idx] = op->Input("X")[0]; + idx_to_fetches_[idx] = op->Input("X")[0]; + } + } + + if (config_.pp_degree == 1) { + if (feeds_.size() == 0) { + LOG(ERROR) << "No feed ops in the inf program, please check the program."; + return false; + } + if (fetches_.size() == 0) { + LOG(ERROR) << "No fetch op in the inf program, please check the program."; + return false; + } + } else { + if (IsPPFirstStage(config_)) { + if (feeds_.size() == 0) { + LOG(ERROR) << "Feed ops are needed for the first pp stage."; + return false; + } else { + LOG(WARNING) << "No feed ops in non-first pp stage."; + } + } else if (feeds_.size() > 0) { + LOG(WARNING) << "Feed op is found in the non-first stage of pp."; + } + if (IsPPLastStage(config_)) { + if (fetches_.size() == 0) { + LOG(ERROR) << "Fetch op is needed for the last pp stage."; + return false; + } else { + LOG(WARNING) << "No fetch op in non-last pp stage."; + } + } else if (fetches_.size() > 0) { + LOG(WARNING) << "Fetch op is found in the non-last stage of pp."; + } + } + return true; +} + +bool DistModel::FeedData(const std::vector &input_data, + framework::Scope *scope) { + VLOG(3) << "DistModel is feeding data."; + if (input_data.size() != feeds_.size()) { + LOG(ERROR) << "Should provide " << feeds_.size() << " feeds, but got " + << input_data.size() << " data."; + return false; + } + feed_tensors_.resize(feeds_.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + // feed each data separately + framework::LoDTensor *input_tensor = &(feed_tensors_[i]); + if (!LoadDataFromDistModelTensor(input_data[i], input_tensor, place_)) { + LOG(ERROR) << "Fail to load data from tensor " << input_data[i].name; + return false; + } + std::string target_name = input_data[i].name; + if (feed_names_.find(target_name) == feed_names_.end()) { + LOG(ERROR) << "The input name [" << target_name + << "] cannot be found in the program." + << " DistModel loads data failed."; + return false; + } + if (input_data[i].dtype != feeds_to_dtype_[target_name]) { + LOG(ERROR) << "Feed var [" << target_name << "] expected dtype is: " + << DistModelDTypeToString(feeds_to_dtype_[target_name]) + << ". But received dtype is: " + << DistModelDTypeToString(input_data[i].dtype) << "."; + return false; + } + int feed_idx = feed_names_[target_name]; + framework::SetFeedVariable(scope, *input_tensor, "feed", feed_idx); + } + return true; +} + +bool DistModel::FetchResults(std::vector *output_data, + framework::Scope *scope) { + VLOG(3) << "DistModel is fetch results."; + output_data->resize(fetches_.size()); + for (size_t i = 0; i < fetches_.size(); ++i) { + int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col")); + VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]"; + PADDLE_ENFORCE_EQ( + static_cast(idx), i, + platform::errors::InvalidArgument( + "Fetch op's col attr(%d) should be equal to the index(%d)", idx, + i)); + framework::FetchType &fetch_var = + framework::GetFetchVariable(*scope, "fetch", idx); + auto &fetch = BOOST_GET(framework::LoDTensor, fetch_var); + auto type = fetch.type(); + auto output = &(output_data->at(i)); + output->name = idx_to_fetches_[idx]; + bool rst = false; + if (type == framework::proto::VarType::FP32) { + rst = FetchResult(fetch, output); + output->dtype = DistModelDataType::FLOAT32; + } else if (type == framework::proto::VarType::INT64) { + rst = FetchResult(fetch, output); + output->dtype = DistModelDataType::INT64; + } else if (type == framework::proto::VarType::INT32) { + rst = FetchResult(fetch, output); + output->dtype = DistModelDataType::INT32; + } else { + LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only " + "supports float32, int64 and int32 fetch type for now."; + } + if (!rst) { + LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx]; + return false; } } return true; } -void DistModel::Run(const std::vector &input_data, +template +bool DistModel::FetchResult(const framework::LoDTensor &fetch, + DistModelTensor *output_data) { + auto shape = framework::vectorize(fetch.dims()); + output_data->shape.assign(shape.begin(), shape.end()); + const T *data = fetch.data(); + int64_t num_elems = fetch.numel(); + output_data->data.Resize(num_elems * sizeof(T)); + // The output of fetch op is always on the cpu, no need switch on place + memcpy(output_data->data.data(), data, num_elems * sizeof(T)); + output_data->lod.clear(); + for (auto &level : fetch.lod()) { + output_data->lod.emplace_back(level.begin(), level.end()); + } + return true; +} + +bool DistModel::Run(const std::vector &input_data, std::vector *output_data) { - /* TODO(fleet exe dev): implement this funct */ + // TODO(fleet exe dev): support pipeline inf mode + VLOG(3) << "DistModel run for once."; + + DistModelTimer timer; + timer.tic(); + + if (!FeedData(input_data, scope_.get())) { + LOG(ERROR) << "DistModel failed at feeding data."; + return false; + } + double feed_elapse = timer.toc(); + VLOG(3) << "Finish loading data, cost " << feed_elapse << "ms."; + + fleet_exe->Run(carrier_id_); + double fleet_exe_elapse = timer.toc(); + VLOG(3) << "Finish FleetExe running, cost " << fleet_exe_elapse - feed_elapse + << "ms."; + + if (!FetchResults(output_data, scope_.get())) { + LOG(ERROR) << "DistModel failed at fetching result."; + return false; + } + double fetch_elapse = timer.toc(); + VLOG(3) << "Finish fetching data, cost " << fetch_elapse - fleet_exe_elapse + << "ms."; + VLOG(3) << "DistModel finish inf, cost " << fetch_elapse << "ms"; + return true; } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.h b/paddle/fluid/distributed/fleet_executor/dist_model.h index 96e9c018074..e6ad94e266a 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.h +++ b/paddle/fluid/distributed/fleet_executor/dist_model.h @@ -19,6 +19,7 @@ #include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" @@ -57,7 +58,7 @@ class DistModel { public: explicit DistModel(const DistModelConfig& config) : config_(config) {} bool Init(); - void Run(const std::vector& input_data, + bool Run(const std::vector& input_data, std::vector* output_data); ~DistModel() = default; @@ -75,12 +76,22 @@ class DistModel { void InsertCommOp(std::string tmp_var_name, int nranks, int rank, const std::vector& peer_endpoints, framework::BlockDesc* block, int ring_id); + bool FeedData(const std::vector& input_data, + framework::Scope* scope); + bool FetchResults(std::vector* output_data, + framework::Scope* scope); + template + bool FetchResult(const framework::LoDTensor& fetch, + DistModelTensor* output_data); + std::string carrier_id_; + std::vector feed_tensors_; std::vector feeds_; std::map feed_names_; std::map idx_to_feeds_; + std::map feeds_to_dtype_; std::vector fetches_; - std::map id_to_fetches_; + std::map idx_to_fetches_; DistModelConfig config_; FleetExecutorDesc executor_desc_; std::shared_ptr fleet_exe; diff --git a/paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h b/paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h index 4a04633388a..6bdd858d6cf 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h +++ b/paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h @@ -62,7 +62,7 @@ class DistModelDataBuf { void Free(); void* data_{nullptr}; size_t length_{0}; - bool memory_owned_{false}; + bool memory_owned_{true}; }; struct DistModelTensor { diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index 450939dd0ff..72ee451fe7c 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -162,7 +162,12 @@ void BindFleetExecutor(py::module* m) { py::class_(*m, "DistModel") .def(py::init()) .def("init", &DistModel::Init) - .def("run", &DistModel::Run, py::call_guard()); + .def("run", + [](DistModel& self, const std::vector& inputs) { + std::vector outputs; + self.Run(inputs, &outputs); + return outputs; + }); py::class_(*m, "DistModelDataBuf") .def(py::init()) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2ac5e9404c1..2e35277d70c 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -156,7 +156,8 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node) - LIST(REMOVE_ITEM TEST_OPS test_dist_model_tensor) + LIST(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_run) + LIST(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_tensor) endif() # Temporally disable test_deprecated_decorator diff --git a/python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_run.py b/python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_run.py new file mode 100644 index 00000000000..544fe4dd43e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_run.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +import os +from paddle.fluid import core + +paddle.enable_static() + + +class TestDistModelRun(unittest.TestCase): + def test_dist_model_run(self): + # step 0: declare folder to save the model and params + folder = './dist_model_run_test/' + file = 'inf' + path_prefix = folder + file + + # step 1: saving the inference model and params + x = paddle.static.data(name='x', shape=[28, 28], dtype='float32') + y = paddle.static.data(name='y', shape=[28, 1], dtype='int64') + predict = paddle.static.nn.fc(x, 10, activation='softmax') + loss = paddle.nn.functional.cross_entropy(predict, y) + avg_loss = paddle.tensor.stat.mean(loss) + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + exe.run(paddle.static.default_startup_program()) + x_data = np.random.randn(28, 28).astype('float32') + y_data = np.random.randint(0, 9, size=[28, 1]).astype('int64') + exe.run(paddle.static.default_main_program(), + feed={'x': x_data, + 'y': y_data}, + fetch_list=[avg_loss]) + paddle.static.save_inference_model(path_prefix, [x, y], [avg_loss], exe) + print('save model to', path_prefix) + + # step 2: prepare fake data for the inference + x_tensor = np.random.randn(28, 28).astype('float32') + y_tensor = np.random.randint(0, 9, size=[28, 1]).astype('int64') + + # step 3: init the dist model to inference with fake data + config = core.DistModelConfig() + config.model_dir = path_prefix + config.place = 'GPU' + dist = core.DistModel(config) + dist.init() + dist_x = core.DistModelTensor(x_tensor, 'x') + dist_y = core.DistModelTensor(y_tensor, 'y') + input_data = [dist_x, dist_y] + output_rst = dist.run(input_data) + dist_model_rst = output_rst[0].as_ndarray().ravel().tolist() + print("dist model rst:", dist_model_rst) + + # step 4: use framework's api to inference with fake data + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.load_inference_model(path_prefix, exe)) + results = exe.run(inference_program, + feed={'x': x_tensor, + 'y': y_tensor}, + fetch_list=fetch_targets) + load_inference_model_rst = results[0] + print("load inference model api rst:", load_inference_model_rst) + + # step 5: compare two results + self.assertTrue(np.allclose(dist_model_rst, load_inference_model_rst)) + + # step 6: clean up the env, delete the saved model and params + os.remove(path_prefix + '.pdiparams') + os.remove(path_prefix + '.pdmodel') + os.rmdir(folder) + print('cleaned up the env') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_model_tensor.py b/python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_tensor.py similarity index 97% rename from python/paddle/fluid/tests/unittests/test_dist_model_tensor.py rename to python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_tensor.py index da25550c4f4..a74b4f0d224 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_model_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_exe_dist_model_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- GitLab