diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
index 0611ec6376d2097c85dad1e5d7430c7b0713a385..1e48f75958a3ada4d1cd5c8d0f920da4fed2157e 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -14,6 +14,8 @@
 #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 
+#include <algorithm>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -27,6 +29,32 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
   }
 }
 
+struct ReduceLoDTensor {
+  const std::vector<LoDTensor> &src_tensors_;
+  LoDTensor &dst_tensor_;
+
+  ReduceLoDTensor(const std::vector<LoDTensor> &src, LoDTensor *dst)
+      : src_tensors_(src), dst_tensor_(*dst) {}
+
+  template <typename T>
+  void operator()() const {
+    PADDLE_ENFORCE(!src_tensors_.empty());
+    auto &t0 = src_tensors_[0];
+    PADDLE_ENFORCE_NE(t0.numel(), 0);
+    dst_tensor_.Resize(t0.dims());
+    T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
+    std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+
+    for (size_t i = 1; i < src_tensors_.size(); ++i) {
+      auto &t = src_tensors_[i];
+      PADDLE_ENFORCE_EQ(t.dims(), t0.dims());
+      PADDLE_ENFORCE_EQ(t.type(), t0.type());
+      std::transform(t.data<T>(), t.data<T>() + t.numel(), dst, dst,
+                     [](T a, T b) -> T { return a + b; });
+    }
+  }
+};
+
 void NCCLAllReduceOpHandle::RunImpl() {
   if (inputs_.size() == 1) {
     return;  // No need to all reduce when GPU count = 1;
@@ -41,40 +69,67 @@ void NCCLAllReduceOpHandle::RunImpl() {
     int dtype = -1;
     size_t numel = 0;
-    std::vector<std::function<void()>> all_reduce_calls;
+    std::vector<LoDTensor> lod_tensors;
 
     for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &p = places_[i];
       auto *s = local_scopes_[i];
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
 
       auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
-      void *buffer = const_cast<void *>(lod_tensor.data<void>());
+      lod_tensors.emplace_back(lod_tensor);
+    }
 
-      if (dtype == -1) {
-        dtype = platform::ToNCCLDataType(lod_tensor.type());
-      }
+    if (platform::is_gpu_place(lod_tensors[0].place())) {
+      std::vector<std::function<void()>> all_reduce_calls;
+      for (size_t i = 0; i < local_scopes_.size(); ++i) {
+        auto &p = places_[i];
+        auto &lod_tensor = lod_tensors[i];
+        void *buffer = const_cast<void *>(lod_tensor.data<void>());
 
-      if (numel == 0) {
-        numel = static_cast<size_t>(lod_tensor.numel());
-      }
+        if (dtype == -1) {
+          dtype = platform::ToNCCLDataType(lod_tensor.type());
+        }
 
-      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
-      all_reduce_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-            comm, stream));
+        if (numel == 0) {
+          numel = static_cast<size_t>(lod_tensor.numel());
+        }
+
+        int dev_id = boost::get<platform::CUDAPlace>(p).device;
+        auto &nccl_ctx = nccl_ctxs_.at(dev_id);
+        auto stream = nccl_ctx.stream();
+        auto comm = nccl_ctx.comm_;
+        all_reduce_calls.emplace_back([=] {
+          PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+              buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
+              ncclSum, comm, stream));
+        });
+      }
+      this->RunAndRecordEvent([&] {
+        platform::NCCLGroupGuard guard;
+        for (auto &call : all_reduce_calls) {
+          call();
+        }
       });
-    }
+    } else {  // Specially handle the gradient of CPU-only operators, e.g. CRF.
+      auto &trg =
+          *this->local_scopes_[0]->Var()->GetMutable<framework::LoDTensor>();
+
+      // Reduce all tensors into trg on CPU.
+      ReduceLoDTensor func(lod_tensors, &trg);
+      VisitDataType(ToDataType(lod_tensors[0].type()), func);
 
-    this->RunAndRecordEvent([&] {
-      platform::NCCLGroupGuard guard;
-      for (auto &call : all_reduce_calls) {
-        call();
+      for (size_t i = 0; i < local_scopes_.size(); ++i) {
+        auto &scope = local_scopes_[i];
+        auto &p = places_[i];
+        auto *var = scope->FindVar(var_name);
+        auto *dev_ctx = dev_ctxes_[p];
+
+        RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
+          auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
+          auto &tensor_cpu = trg;
+          TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
+        });
       }
-    });
+    }
   }
 }
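Note: the CPU branch above reproduces what `ncclAllReduce(..., ncclSum, ...)` gives the GPU branch: sum N identically shaped gradient tensors, then redistribute the result to every place. A minimal, self-contained sketch of the copy-then-transform reduction that ReduceLoDTensor performs once VisitDataType has resolved the runtime dtype to a concrete T (`SumInto` and all scaffolding here are hypothetical names, not Paddle code):

```cpp
// Sketch only: seed the destination with the first source, then accumulate
// the remaining sources element-wise, exactly as ReduceLoDTensor does with
// std::copy followed by std::transform.
#include <algorithm>
#include <cassert>
#include <vector>

template <typename T>
void SumInto(const std::vector<std::vector<T>> &srcs, std::vector<T> *dst) {
  assert(!srcs.empty());
  dst->assign(srcs[0].begin(), srcs[0].end());  // dst = srcs[0]
  for (size_t i = 1; i < srcs.size(); ++i) {
    assert(srcs[i].size() == dst->size());  // shapes must match
    std::transform(srcs[i].begin(), srcs[i].end(), dst->begin(), dst->begin(),
                   [](T a, T b) { return a + b; });
  }
}

int main() {
  // Two "device" gradients reduced on the CPU: total == {4.0f, 6.0f}.
  std::vector<std::vector<float>> grads = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  std::vector<float> total;
  SumInto(grads, &total);
  return (total[0] == 4.0f && total[1] == 6.0f) ? 0 : 1;
}
```

Resolving the dtype once, outside the loops, keeps the inner loops monomorphic; that is why ReduceLoDTensor is a templated functor handed to VisitDataType rather than a per-element switch.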
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
index 846bc21be27cf6c889ae34b967b8bff3c60ab743..534d77860f87be08c8834efd373d90eb199ed6a2 100644
--- a/paddle/fluid/framework/details/op_handle_base.cc
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -107,6 +107,22 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #endif
 }
 
+void OpHandleBase::RunAndRecordEvent(platform::Place p,
+                                     const std::function<void()> &callback) {
+#ifdef PADDLE_WITH_CUDA
+  if (platform::is_cpu_place(p) || events_.empty()) {
+    callback();
+  } else {
+    auto *ctx = dev_ctxes_.at(p);
+    auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(ctx);
+    cuda_ctx->RecordEvent(events_.at(boost::get<platform::CUDAPlace>(p).device),
+                          callback);
+  }
+#else
+  callback();
+#endif
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index 1aacba5a4c3c959b6f584aa5f4dcdc5c0dc43e76..a9a6c8d39cf8741f7d9c91579a650ad742cec381 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -64,6 +64,9 @@ class OpHandleBase {
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
 
+  void RunAndRecordEvent(platform::Place p,
+                         const std::function<void()> &callback);
+
   virtual void RunImpl() = 0;
 };
 
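Note: the new `RunAndRecordEvent(place, callback)` overload records an event only when the callback ran against a GPU place; CPU-only work has no stream to synchronize on, so the callback runs bare. A hedged sketch of the underlying idiom with the raw CUDA runtime API (`RunAndRecord` is a hypothetical stand-in, not the OpHandleBase interface):

```cpp
// Sketch only: models the dispatch inside OpHandleBase::RunAndRecordEvent(p, cb).
#include <cuda_runtime.h>

#include <functional>

void RunAndRecord(bool is_cpu_place, cudaStream_t stream, cudaEvent_t event,
                  const std::function<void()> &enqueue_work) {
  if (is_cpu_place) {
    enqueue_work();  // CPU work runs synchronously; nothing to record.
    return;
  }
  enqueue_work();                  // enqueues kernels / async copies on `stream`
  cudaEventRecord(event, stream);  // event fires once that queued work drains
}
```

Recording on the same stream that received the work lets downstream op handles wait on exactly that work instead of issuing a device-wide synchronize.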
diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
index 4d8bca4d2430a248ccf421572bdafdffc3a3003a..d9cd76952e31f8185512ab45f9f3ab2ce7d9da48 100644
--- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py
+++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
@@ -12,17 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import math
-
 import numpy as np
+import os
+import time
+import unittest
+
 import paddle
 import paddle.dataset.conll05 as conll05
 import paddle.fluid as fluid
-from paddle.fluid.initializer import init_on_cpu
-import contextlib
-import time
-import unittest
-import os
 
 word_dict, verb_dict, label_dict = conll05.get_dict()
 word_dict_len = len(word_dict)
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index 95845ea4de54ad43754ec5811d28ed52a8a3ae86..83d22fd799eea55eedb58f93421b275985edb50b 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -505,3 +505,148 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
                 train_loss, test_loss, atol=1e-8),
             "Train loss: " + str(train_loss) + "\n Test loss:" +
             str(test_loss))
+
+
+import paddle.dataset.conll05 as conll05
+import paddle.fluid as fluid
+
+word_dict, verb_dict, label_dict = conll05.get_dict()
+word_dict_len = len(word_dict)
+label_dict_len = len(label_dict)
+pred_dict_len = len(verb_dict)
+mark_dict_len = 2
+word_dim = 32
+mark_dim = 5
+hidden_dim = 512
+depth = 8
+mix_hidden_lr = 1e-3
+embedding_name = 'emb'
+
+
+def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
+            **ignored):
+    # 8 features
+    predicate_embedding = fluid.layers.embedding(
+        input=predicate,
+        size=[pred_dict_len, word_dim],
+        dtype='float32',
+        param_attr='vemb')
+
+    mark_embedding = fluid.layers.embedding(
+        input=mark, size=[mark_dict_len, mark_dim], dtype='float32')
+
+    word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
+    emb_layers = [
+        fluid.layers.embedding(
+            size=[word_dict_len, word_dim],
+            input=x,
+            param_attr=fluid.ParamAttr(
+                name=embedding_name, trainable=False)) for x in word_input
+    ]
+    emb_layers.append(predicate_embedding)
+    emb_layers.append(mark_embedding)
+
+    hidden_0_layers = [
+        fluid.layers.fc(input=emb, size=hidden_dim, act='tanh')
+        for emb in emb_layers
+    ]
+
+    hidden_0 = fluid.layers.sums(input=hidden_0_layers)
+
+    lstm_0 = fluid.layers.dynamic_lstm(
+        input=hidden_0,
+        size=hidden_dim,
+        candidate_activation='relu',
+        gate_activation='sigmoid',
+        cell_activation='sigmoid')
+
+    # stack L-LSTM and R-LSTM with direct edges
+    input_tmp = [hidden_0, lstm_0]
+
+    for i in range(1, depth):
+        mix_hidden = fluid.layers.sums(input=[
+            fluid.layers.fc(input=input_tmp[0], size=hidden_dim, act='tanh'),
+            fluid.layers.fc(input=input_tmp[1], size=hidden_dim, act='tanh')
+        ])
+
+        lstm = fluid.layers.dynamic_lstm(
+            input=mix_hidden,
+            size=hidden_dim,
+            candidate_activation='relu',
+            gate_activation='sigmoid',
+            cell_activation='sigmoid',
+            is_reverse=((i % 2) == 1))
+
+        input_tmp = [mix_hidden, lstm]
+
+    feature_out = fluid.layers.sums(input=[
+        fluid.layers.fc(input=input_tmp[0], size=label_dict_len, act='tanh'),
+        fluid.layers.fc(input=input_tmp[1], size=label_dict_len, act='tanh')
+    ])
+
+    return feature_out
+
+
+class TestCRFModel(unittest.TestCase):
+    def test_all(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            word = fluid.layers.data(
+                name='word_data', shape=[1], dtype='int64', lod_level=1)
+            predicate = fluid.layers.data(
+                name='verb_data', shape=[1], dtype='int64', lod_level=1)
+            ctx_n2 = fluid.layers.data(
+                name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1)
+            ctx_n1 = fluid.layers.data(
+                name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1)
+            ctx_0 = fluid.layers.data(
+                name='ctx_0_data', shape=[1], dtype='int64', lod_level=1)
+            ctx_p1 = fluid.layers.data(
+                name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1)
+            ctx_p2 = fluid.layers.data(
+                name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
+            mark = fluid.layers.data(
+                name='mark_data', shape=[1], dtype='int64', lod_level=1)
+            feature_out = db_lstm(**locals())
+            target = fluid.layers.data(
+                name='target', shape=[1], dtype='int64', lod_level=1)
+            crf_cost = fluid.layers.linear_chain_crf(
+                input=feature_out,
+                label=target,
+                param_attr=fluid.ParamAttr(
+                    name='crfw', learning_rate=1e-1))
+            avg_cost = fluid.layers.mean(crf_cost)
+
+            sgd_optimizer = fluid.optimizer.SGD(
+                learning_rate=fluid.layers.exponential_decay(
+                    learning_rate=0.01,
+                    decay_steps=100000,
+                    decay_rate=0.5,
+                    staircase=True))
+            sgd_optimizer.minimize(avg_cost)
+
+            train_data = paddle.batch(
+                paddle.reader.shuffle(
+                    paddle.dataset.conll05.test(), buf_size=8192),
+                batch_size=16)
+
+            place = fluid.CUDAPlace(0)
+            exe = fluid.Executor(place)
+            exe.run(startup)
+
+            pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
+
+            feeder = fluid.DataFeeder(
+                feed_list=[
+                    word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate,
+                    mark, target
+                ],
+                place=fluid.CPUPlace())
+
+            data = train_data()
+            for i in xrange(10):
+                cur_batch = next(data)
+                print map(numpy.array,
+                          pe.run(feed_dict=feeder.feed(cur_batch),
+                                 fetch_list=[avg_cost.name]))[0]