diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 5a67b358ddabb12566cd4ffe00cb12c65a185099..a9378d61c3ca39bd43b558633cc4d04c40175cac 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -110,7 +110,7 @@ void prefetch_core( int pservers = context.Attr("pserver_num"); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &actual_ctx = *pool.Get(context.GetPlace()); + auto &actual_ctx = *pool.Get(platform::CPUPlace()); std::unique_ptr local_scope = scope.NewTmpScope(); @@ -144,7 +144,6 @@ void prefetch_core( VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; } } - for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( "internal error in RPCClient")); @@ -167,6 +166,7 @@ void prefetch_core( for (int64_t i = 0; i < dims[0]; ++i) { auto origin_id = ids_in_this_section[i]; std::vector vecs(row_numel); + std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin()); (*recved_vec_map)[origin_id] = vecs; } @@ -213,18 +213,18 @@ void prefetchs(const std::vector &id_var_names, const auto place = scope.FindVar(id_var_names[0])->Get().place(); - if (!platform::is_cpu_place(place)) { - PADDLE_THROW("multi prefetch only support CPU currently"); - } - + std::vector> ids_group; std::vector ids_union; + std::vector ids_lods; TableAndEndpoints tables; for (auto &id_name : id_var_names) { - auto *in_var = scope.FindVar(id_name); - auto &id_tensor = in_var->Get(); - std::copy_n(id_tensor.data(), id_tensor.numel(), - back_inserter(ids_union)); + auto &id_tensor = scope.FindVar(id_name)->Get(); + std::vector ids; + TensorToVector(id_tensor, context.device_context(), &ids); + ids_union.insert(ids_union.end(), ids.begin(), ids.end()); + ids_group.push_back(ids); + ids_lods.push_back(id_tensor.lod()); } std::unordered_set s(ids_union.begin(), ids_union.end()); @@ -258,25 +258,48 @@ void prefetchs(const std::vector &id_var_names, } for (size_t i = 0; i < out_var_names.size(); i++) { - auto *in_var = scope.FindVar(id_var_names[i]); - auto &id_tensor = in_var->Get(); - auto ids_size = id_tensor.dims()[0]; - const auto *id_data = id_tensor.data(); - + std::vector ids = ids_group[i]; + auto ids_size = ids.size(); auto *out_t = scope.FindVar(out_var_names[i])->GetMutable(); - out_t->set_lod(id_tensor.lod()); - out_t->Resize(framework::make_ddim({ids_size, vec_dim_1})); + out_t->set_lod(ids_lods[i]); + out_t->Resize( + framework::make_ddim({static_cast(ids_size), vec_dim_1})); auto *out_d = out_t->mutable_data(place); - for (auto idx = 0; idx < static_cast(ids_size); idx++) { - const auto &id = id_data[idx]; - if (padding_idx != distributed::kNoPadding && id == padding_idx) { - memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1); - } else { - std::copy_n(recved_vec_map[id].begin(), vec_dim_1, - out_d + idx * vec_dim_1); + if (platform::is_cpu_place(out_t->place())) { + for (auto idx = 0; idx < static_cast(ids_size); idx++) { + const auto &id = ids[idx]; + if (padding_idx != distributed::kNoPadding && id == padding_idx) { + memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1); + } else { + std::copy_n(recved_vec_map[id].begin(), vec_dim_1, + out_d + idx * vec_dim_1); + } + } + } else { +#ifdef PADDLE_WITH_CUDA + for (auto idx = 0; idx < static_cast(ids_size); idx++) { + const auto &id = ids[idx]; + auto stream = context.cuda_device_context().stream(); + if (padding_idx != distributed::kNoPadding && id == padding_idx) { + platform::GpuMemsetAsync(out_d + idx * vec_dim_1, 0, + sizeof(float) * vec_dim_1, stream); + } else { + auto &cpu_place = + BOOST_GET_CONST(platform::CPUPlace, + paddle::platform::CPUDeviceContext().GetPlace()); + auto &gpu_place = + BOOST_GET_CONST(platform::CUDAPlace, out_t->place()); + memory::Copy(gpu_place, out_d + idx * vec_dim_1, cpu_place, + &recved_vec_map[id][0], sizeof(float) * vec_dim_1, + stream); + } } +#else + PADDLE_ENFORCE(true, platform::errors::PermissionDenied( + "Paddle is not compiled with GPU!")); +#endif } } } diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc index 3037a63b0d7b4e8812e67fdfb776f89ea43eb546..8c093d12585981ee681ae13f0d2e493197c6b9b3 100644 --- a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/distributed/parameter_prefetch.h" +#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { @@ -75,47 +73,6 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { } }; -template -class DistributedLookupTableKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto ids_vars = context.MultiInputVar("Ids"); - auto emb_vars = context.MultiOutput("Embeddings"); - - auto id_names = context.InputNames("Ids"); - auto embedding_name = context.InputNames("W").front(); - auto out_names = context.OutputNames("Outputs"); - auto lookup_tables = context.Attr>("table_names"); - auto endpoints = context.Attr>("endpoints"); - auto is_distributed = context.Attr("is_distributed"); - - auto lookup_table_version = - context.Attr("lookup_table_version"); - - operators::distributed::prefetchs(id_names, out_names, embedding_name, - is_distributed, lookup_tables, endpoints, - context, context.scope()); - - if (lookup_table_version == "lookup_table_v2") { - auto &scope = context.scope(); - auto emb_dim = - scope.FindVar(embedding_name)->Get().dims()[1]; - - for (size_t i = 0; i < id_names.size(); ++i) { - auto *id_var = scope.FindVar(id_names[i]); - auto *out_var = scope.FindVar(out_names[i]); - auto *id_tensor = id_var->GetMutable(); - auto *out_tensor = out_var->GetMutable(); - - auto id_dims = id_tensor->dims(); - out_tensor->Resize(framework::make_ddim( - {static_cast(id_dims[0]), static_cast(id_dims[1]), - static_cast(emb_dim)})); - } - } - } -}; - class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -170,15 +127,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Lookup Tablel Prefetch Operator. - This operator is used to perform lookup on parameter W, then concatenated into a sparse tensor. - The type of Ids(Input) is SelectedRows, the rows of Ids contains the ids to be looked up in W; if the Id is not in the sparse table, this operator will return a random value and set the value into the table for the next looking up. - )DOC"); } }; @@ -191,4 +145,5 @@ REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp, ops::DistributedLookupTableOpMaker); REGISTER_OP_CPU_KERNEL(distributed_lookup_table, - ops::DistributedLookupTableKernel); + ops::DistributedLookupTableKernel< + paddle::platform::CPUDeviceContext, float>); diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cu.cc b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..54c894415096e869f363eda6a1de2a473e839263 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cu.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + distributed_lookup_table, + ops::DistributedLookupTableKernel); diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a71451c78a870b71c05b41bdcfb34a85b3e2213b --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/distributed/parameter_prefetch.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class DistributedLookupTableKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto ids_vars = context.MultiInputVar("Ids"); + auto emb_vars = context.MultiOutput("Embeddings"); + + auto id_names = context.InputNames("Ids"); + auto embedding_name = context.InputNames("W").front(); + auto out_names = context.OutputNames("Outputs"); + auto lookup_tables = context.Attr>("table_names"); + auto endpoints = context.Attr>("endpoints"); + auto is_distributed = context.Attr("is_distributed"); + + operators::distributed::prefetchs(id_names, out_names, embedding_name, + is_distributed, lookup_tables, endpoints, + context, context.scope()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/recv_save_op.cc b/paddle/fluid/operators/distributed_ops/recv_save_op.cc index ccc30d1ea082a6f69b71059631247144c931116e..d194fcda36a474fa208f5d5a67e425ba5a5a3303 100644 --- a/paddle/fluid/operators/distributed_ops/recv_save_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_save_op.cc @@ -44,7 +44,7 @@ class RecvSaveOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::proto::VarType::Type(ctx.Attr("dtype")), - ctx.GetPlace()); + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 7528422fdc09b7894898bdee94eaa11ad2cba311..f20bada8ab288fe74fd8ca82a73522a22b234191 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index d2c7397c85f8df155444d9272c7b75596f0fe169..d555e9876a08cecb7c32057a61a92e0cc9dacf7b 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -393,6 +393,12 @@ class FleetTranspiler(Fleet): "in fleet.save_inference_model() function, executor must be as Executor type" ) + # Todo(MrChengmo): support recv&save GPU-Kernel for ps-gpu model save + if not isinstance(executor.place, fluid.CPUPlace): + save_executor = Executor(fluid.CPUPlace()) + else: + save_executor = executor + if main_program is not None: if isinstance(main_program, CompiledProgram): raise TypeError( @@ -670,6 +676,11 @@ if you would like to save all variables in a raise TypeError( "in fleet.save_persistables() function, executor must be as Executor type" ) + # Todo(MrChengmo): support recv&save GPU-Kernel for ps-gpu model save + if not isinstance(executor.place, fluid.CPUPlace): + save_executor = Executor(fluid.CPUPlace()) + else: + save_executor = executor if main_program is None: main_program = self.main_program @@ -679,7 +690,8 @@ if you would like to save all variables in a "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed" ) - self._save_distributed_persistables(executor, dirname, main_program) + self._save_distributed_persistables(save_executor, dirname, + main_program) @staticmethod def __exclude_vars(exclude_var_names=[]): diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..03d0fa447daf3e3a502e7d77491045f92695496c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py @@ -0,0 +1,152 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Distribute CTR model for test fleet api +""" + +from __future__ import print_function + +import shutil +import tempfile +import time + +import paddle +import paddle.fluid as fluid +import os +import numpy as np + +import ctr_dataset_reader +from test_dist_fleet_base import runtime_main, FleetDistRunnerBase +from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader +from paddle.distributed.fleet.base.util_factory import fleet_util + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistGpuPsCTR2x2(TestDistCTR2x2): + """ + For test CTR model, using Fleet api & PS-GPU + """ + + def check_model_right(self, dirname): + model_filename = os.path.join(dirname, "__model__") + + with open(model_filename, "rb") as f: + program_desc_str = f.read() + + program = fluid.Program.parse_from_string(program_desc_str) + with open(os.path.join(dirname, "__model__.proto"), "w") as wn: + wn.write(str(program)) + + def do_pyreader_training(self, fleet): + """ + do training using dataset, using fetch handler to catch variable + Args: + fleet(Fleet api): the fleet object of Parameter Server, define distribute training role + """ + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + exe = fluid.Executor(place) + fleet.init_worker() + exe.run(fleet.startup_program) + + batch_size = 4 + train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) + self.reader.decorate_sample_list_generator(train_reader) + + for epoch_id in range(1): + self.reader.start() + try: + pass_start = time.time() + while True: + loss_val = exe.run(program=fleet.main_program, + fetch_list=[self.avg_cost.name]) + loss_val = np.mean(loss_val) + reduce_output = fleet_util.all_reduce( + np.array(loss_val), mode="sum") + loss_all_trainer = fleet_util.all_gather(float(loss_val)) + loss_val = float(reduce_output) / len(loss_all_trainer) + message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id, + loss_val) + fleet_util.print_on_rank(message, 0) + + pass_time = time.time() - pass_start + except fluid.core.EOFException: + self.reader.reset() + + model_dir = tempfile.mkdtemp() + fleet.save_inference_model( + exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost) + self.check_model_right(model_dir) + if fleet.is_first_worker(): + fleet.save_persistables(executor=exe, dirname=model_dir) + shutil.rmtree(model_dir) + fleet.stop_worker() + + def do_dataset_training(self, fleet): + dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data( + ) + + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + exe = fluid.Executor(place) + + fleet.init_worker() + exe.run(fleet.startup_program) + + thread_num = 2 + batch_size = 128 + filelist = [] + for _ in range(thread_num): + filelist.append(train_file_path) + + # config dataset + dataset = paddle.fleet.DatasetFactory().create_dataset() + dataset.set_batch_size(batch_size) + dataset.set_use_var(self.feeds) + pipe_command = 'python ctr_dataset_reader.py' + dataset.set_pipe_command(pipe_command) + + dataset.set_filelist(filelist) + dataset.set_thread(thread_num) + + for epoch_id in range(1): + pass_start = time.time() + dataset.set_filelist(filelist) + exe.train_from_dataset( + program=fleet.main_program, + dataset=dataset, + fetch_list=[self.avg_cost], + fetch_info=["cost"], + print_period=2, + debug=int(os.getenv("Debug", "0"))) + pass_time = time.time() - pass_start + + if os.getenv("SAVE_MODEL") == "1": + model_dir = tempfile.mkdtemp() + fleet.save_inference_model(exe, model_dir, + [feed.name for feed in self.feeds], + self.avg_cost) + self.check_model_right(model_dir) + if fleet.is_first_worker(): + fleet.save_persistables(executor=exe, dirname=model_dir) + shutil.rmtree(model_dir) + + fleet.stop_worker() + + +if __name__ == "__main__": + runtime_main(TestDistGpuPsCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py index f72850f949715caa7282e6ca85f148a75ebc1074..8c700ff3217b550ef3857ec7631348cfbe4e7b7f 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py @@ -278,6 +278,23 @@ class TestFleetBase(unittest.TestCase): tr0_ret = tr0.returncode tr1_ret = tr0.returncode + if tr0_ret != 0: + print( + "========================Error tr0_err begin===========================" + ) + os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log")) + print( + "========================Error tr0_err end===========================" + ) + + if tr1_ret != 0: + print( + "========================Error tr1_err begin===========================" + ) + os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log")) + print( + "========================Error tr1_err end===========================" + ) self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check") self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check") diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 18629c4f996a6d068339bd6cad494e8e8d21123f..e16e91192e1fefb6ecbcb440971bd6b5eb8c820e 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -156,5 +156,40 @@ class TestDistCtrHalfAsync2x2(TestFleetBase): "dist_fleet_ctr.py", delta=1e-5, check_error_log=True) +class TestDistCtrPsGpuPyreaderAsync2x2(TestFleetBase): + def _setup_config(self): + self._mode = "async" + self._reader = "pyreader" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_rpc_deadline": "30000", # 5sec to fail fast + "http_proxy": "", + "FLAGS_communicator_send_queue_size": "2", + "FLAGS_communicator_max_merge_var_num": "2", + "CPU_NUM": "2", + "SAVE_MODEL": "1" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) + + def test_dist_train(self): + self.check_with_place( + "dist_fleet_ctr_ps_gpu.py", delta=1e-5, check_error_log=True) + + if __name__ == "__main__": unittest.main()