diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc
index 8a3af0a8ebd5bad7be7046fa399cca4920da3d71..7f64bc75ae8ad40a268739cdc36051e76af9f49a 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph.cc
@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) {  // NOLINT
                         std::vector<Node *>(outputs.begin(), outputs.end()));
 }
 
+void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
+  std::vector<Node *> op_nodes;
+  for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
+    if (node.type() == Node::Type::kValue || node.deleted()) {
+      continue;
+    }
+    op_nodes.push_back(&node);
+  }
+  size_t op_num = op_nodes.size();
+  for (size_t i = 0; i < op_num; i++) {
+    if (op_nodes[i]->type() == Node::Type::kFunction) continue;
+    std::unordered_set<std::string> follow_up_input_names;
+    for (size_t j = i + 1; j < op_num; j++) {
+      for (auto *in : op_nodes[j]->inlinks) {
+        follow_up_input_names.insert(in->name());
+      }
+    }
+    std::vector<Node *> filtered_subgraph_outlinks;
+    for (auto *out : op_nodes[i]->outlinks) {
+      if (follow_up_input_names.count(out->name())) {
+        filtered_subgraph_outlinks.push_back(out);
+      }
+    }
+    PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
+    op_nodes[i]->outlinks = filtered_subgraph_outlinks;
+  }
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h
index 16aeae4d35e7bd54646053190da7f47eaca69aa0..bb3ec6bbc1d9555386aba8837b019d2511653258 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.h
+++ b/paddle/fluid/inference/analysis/data_flow_graph.h
@@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> {
 std::pair<std::vector<Node *>, std::vector<Node *>>
 ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph);  // NOLINT
 
+void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
index aaf7ca67011fb7bd4a74f6d8f57317594c528ca4..18c32fa09199003f17183207828cdfe4e627ae1a 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
 bool DataFlowGraphToFluidPass::Finalize() { return true; }
 
 void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
+  FilterRedundantOutputOfSubGraph(graph);
   LOG(INFO) << "graph.inputs " << graph->inputs.size();
   for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
     if (node.deleted()) continue;
diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
index a6f85484756417e103cbb60bcb664e8b800b9f28..c05b0e5d4690d0a447edf63a149903704bc2c9be 100644
--- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
+++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
   for (size_t i = 0; i < graph->nodes.size(); i++) {
     const Node &node = graph->nodes.Get(i);
     if (!config_.display_deleted_node && node.deleted()) continue;
-    for (auto &in : node.inlinks) {
-      if (!config_.display_deleted_node && in->deleted()) continue;
-      dot.AddEdge(in->repr(), node.repr(), {});
+    for (auto &out : node.outlinks) {
+      if (!config_.display_deleted_node && out->deleted()) continue;
+      dot.AddEdge(node.repr(), out->repr(), {});
     }
   }
   return dot.Build();
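[Note] For reference, the new FilterRedundantOutputOfSubGraph pass above visits ops in topological order and keeps only the outputs that some later op actually reads. A minimal standalone sketch of that filtering step (ToyNode and FilterOutputs are hypothetical names for illustration, not the Paddle API):

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Hypothetical stand-in for analysis::Node, just enough to show the idea.
    struct ToyNode {
      std::string name;
      std::vector<ToyNode *> inlinks, outlinks;
    };

    // Keep only the outputs of `node` that some later node actually reads.
    void FilterOutputs(ToyNode *node, const std::vector<ToyNode *> &later_nodes) {
      std::unordered_set<std::string> wanted;
      for (auto *later : later_nodes)
        for (auto *in : later->inlinks) wanted.insert(in->name);

      std::vector<ToyNode *> kept;
      for (auto *out : node->outlinks)
        if (wanted.count(out->name)) kept.push_back(out);
      node->outlinks = kept;  // drop outputs nothing downstream consumes
    }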
diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc
index e74f23ff969f5a8f58a71da337c16dcbc14f10c0..63c3f0d7b3f5c2b9246e2b041796caf5eb562826 100644
--- a/paddle/fluid/inference/api/api.cc
+++ b/paddle/fluid/inference/api/api.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <cstring>
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 
 namespace paddle {
@@ -40,19 +41,36 @@ PaddleBuf::PaddleBuf(PaddleBuf&& other)
 PaddleBuf::PaddleBuf(const PaddleBuf& other) { *this = other; }
 
 PaddleBuf& PaddleBuf::operator=(const PaddleBuf& other) {
+  if (!other.memory_owned_) {
+    data_ = other.data_;
+    length_ = other.length_;
+    memory_owned_ = other.memory_owned_;
+  } else {
+    Resize(other.length());
+    memcpy(data_, other.data(), other.length());
+    length_ = other.length();
+    memory_owned_ = true;
+  }
+  return *this;
+}
+
+PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
   // only the buffer with external memory can be copied
-  assert(!other.memory_owned_);
   data_ = other.data_;
   length_ = other.length_;
   memory_owned_ = other.memory_owned_;
+  other.data_ = nullptr;
+  other.length_ = 0;
+  other.memory_owned_ = false;
   return *this;
 }
 
 void PaddleBuf::Resize(size_t length) {
   // Only the owned memory can be reset, the external memory can't be changed.
   if (length_ == length) return;
-  assert(memory_owned_);
-  Free();
+  if (memory_owned_) {
+    Free();
+  }
   data_ = new char[length];
   length_ = length;
   memory_owned_ = true;
@@ -68,7 +86,7 @@ void PaddleBuf::Reset(void* data, size_t length) {
 void PaddleBuf::Free() {
   if (memory_owned_ && data_) {
     assert(length_ > 0);
-    delete static_cast<char*>(data_);
+    delete[] static_cast<char*>(data_);
     data_ = nullptr;
     length_ = 0;
   }
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 59b0df7968cce137843ba8cad38a62fdb8d3bfc1..b24414e8245b1a4d90acce4fa1ad5690e06b47dd 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -40,11 +40,12 @@ class PaddleBuf {
   // Copy only available when memory is managed externally.
   explicit PaddleBuf(const PaddleBuf&);
   PaddleBuf& operator=(const PaddleBuf&);
+  PaddleBuf& operator=(PaddleBuf&&);
   // Do not own the memory.
   PaddleBuf(void* data, size_t length)
       : data_(data), length_(length), memory_owned_{false} {}
   // Own memory.
-  explicit PaddleBuf(size_t length)
+  PaddleBuf(size_t length)
       : data_(new char[length]), length_(length), memory_owned_(true) {}
   // Resize to `length` bytes.
   void Resize(size_t length);
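[Note] With this patch, copying a PaddleBuf deep-copies owned memory instead of asserting, and the new move assignment transfers ownership and empties the source. A small usage sketch of the behavior implied by the diff above:

    #include <cassert>
    #include <utility>
    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::PaddleBuf a(16);       // owns 16 heap bytes
      paddle::PaddleBuf b(a);        // copy ctor delegates to the new operator=
      assert(b.data() != a.data());  // owned memory is now deep-copied

      char external[8];
      paddle::PaddleBuf c(external, sizeof(external));  // non-owning view
      c = std::move(a);              // new move assignment takes over a's buffer
      assert(a.data() == nullptr);   // a is left empty, so no double free
      return 0;
    }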
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index eb8272e90c32c3a0be2c0ce1bc679571af876317..bc3e95e904f8b6c2cdd2ae6685bf67580178e6b6 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                          const framework::Tensor& dout, int axis,
                          framework::Tensor* dx, framework::Tensor* dy,
                          DX_OP dx_op, DY_OP dy_op) {
-  const framework::DDim x_dim = x.dims();
-  const framework::DDim y_dim = y.dims();
+  const framework::DDim& x_dim = x.dims();
+  const framework::DDim& y_dim = y.dims();
   if (x.dims() == y.dims()) {
     ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
         ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
@@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
                                  framework::Tensor* dx, framework::Tensor* dy,
                                  DX_OP dx_op, DY_OP dy_op) {
   if (dy == nullptr) {
-    const framework::DDim dx_dims = dout.dims();
+    const framework::DDim& dx_dims = dout.dims();
     auto dy_dims = dx_dims;
     ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
         ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   } else {
     if (dout.dims() == dy->dims()) {
-      const framework::DDim dx_dims = dout.dims();
-      const framework::DDim dy_dims = dy->dims();
+      const framework::DDim& dx_dims = dout.dims();
+      const framework::DDim& dy_dims = dy->dims();
       ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
           ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     } else {  // Y is a scalar
       auto dx_dims = dout.dims();
-      const framework::DDim dy_dims = dy->dims();
+      const framework::DDim& dy_dims = dy->dims();
       ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
           ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     }
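[Note] The elementwise change only swaps DDim copies for const references: dims() hands back a reference to storage the tensor already owns, so binding a reference avoids one copy per grad call. The same pattern in miniature (Dims and TensorLike are hypothetical stand-ins, not framework types):

    #include <array>
    #include <cstdint>

    struct Dims {
      std::array<int64_t, 9> d;  // fixed-capacity dims; copying is not free
      int rank;
    };

    struct TensorLike {
      Dims dims_;
      const Dims& dims() const { return dims_; }  // returns a reference
    };

    void Grad(const TensorLike& x) {
      const Dims& x_dim = x.dims();  // binds the reference: no copy
      // const Dims x_dim = x.dims();  // would copy the whole array each call
      (void)x_dim;
    }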
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index 8f7840cee1dd95a828fd4ac8815e335a5db47e3d..a559b01ed32a48e3befb37c2ae8935b4f3a4acb0 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #define EIGEN_USE_GPU
 
+#include <cub/cub.cuh>
+#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
 
 namespace paddle {
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
     logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
   }
 }
+
 }  // namespace
 
+static __device__ __forceinline__ float real_exp(float x) { return expf(x); }
+static __device__ __forceinline__ double real_exp(double x) { return exp(x); }
+static __device__ __forceinline__ float real_log(float x) {
+  return math::TolerableValue<float>()(logf(x));
+}
+static __device__ __forceinline__ double real_log(double x) {
+  return math::TolerableValue<double>()(log(x));
+}
+
+/** In the following code, 3 CUDA kernels are implemented to calculate softmax
+ * and loss **/
+/*
+  Supposing x is `logits` and y is `labels`, the equations are as follows:
+
+  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
+        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
+        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
+        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
+        = \sum_{j}(-y_i_j * tmp_i_j)
+
+  softmax_i_j = e^{tmp_i_j}
+
+where:
+  max_i = \max_{j}{x_i_j}
+  logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
+  tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
+
+Therefore, the calculation can be separated into 3 steps:
+Step 1: row-wise operation to calculate max_i
+Step 2: row-wise operation to calculate logDiffMaxSum_i
+Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
+
+To save memory, we can share memory among max_i, logDiffMaxSum_i and
+cross\_entropy_i.
+In this way, the 3 steps should be changed to:
+Step 1 (RowReductionForMax): row-wise operation to calculate max_i
+Step 2 (RowReductionForDiffMaxSum): calculate intermediate result softmax'_i_j =
+x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
+Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
+- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
+*/
+
+// There are 3 kinds of reduce algorithms in cub:
+// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
+// BLOCK_REDUCE_RAKING
+// BLOCK_REDUCE_WARP_REDUCTIONS (default)
+template <typename T, int BlockDim>
+using BlockReduce = cub::BlockReduce<T, BlockDim>;
+
+template <typename T, int BlockDim>
+using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
+
+// Make sure that BlockDim <= feature_size
+// This kernel is used to calculate the max element of each row
+template <typename T, int BlockDim>
+__global__ void RowReductionForMax(const T* logits_data, T* max_data,
+                                   int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  T cur_max = logits_data[beg_idx];
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    if (cur_max < logits_data[beg_idx]) {
+      cur_max = logits_data[beg_idx];
+    }
+    beg_idx += BlockDim;
+  }
+
+  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());
+
+  if (threadIdx.x == 0) {
+    max_data[blockIdx.x] = cur_max < -64 ? -64 : cur_max;
+  }
+}
+
+// Make sure that BlockDim <= feature_size
+template <typename T, int BlockDim>
+__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
+                                          T* softmax, int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  auto block_max = max_data[blockIdx.x];
+
+  softmax[beg_idx] = logits_data[beg_idx] - block_max;
+  T diff_max_sum = real_exp(softmax[beg_idx]);
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    softmax[beg_idx] = logits_data[beg_idx] - block_max;
+    diff_max_sum += real_exp(softmax[beg_idx]);
+    beg_idx += BlockDim;
+  }
+
+  diff_max_sum =
+      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
+  if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
+}
+
+// Make sure that BlockDim <= feature_size
+template <typename T, int BlockDim>
+__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
+                                                      const T* labels_data,
+                                                      T* loss_data, T* softmax,
+                                                      int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  // log_diff_max_sum shares memory with loss
+  auto block_log_diff_max_sum = loss_data[blockIdx.x];
+  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
+  softmax[beg_idx] = real_exp(tmp);
+  auto loss = -labels_data[beg_idx] * tmp;
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    tmp = softmax[beg_idx] - block_log_diff_max_sum;
+    softmax[beg_idx] = real_exp(tmp);
+    loss -= (labels_data[beg_idx] * tmp);
+    beg_idx += BlockDim;
+  }
+
+  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
+  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
+}
+
+template <typename T>
+__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
+  auto idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < batch_size) out[idx] = static_cast<T>(1);
+}
+
+template <typename T>
+static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
+                                               const T* labels_data,
+                                               T* softmax_data, T* loss_data,
+                                               int batch_size, int feature_size,
+                                               cudaStream_t stream) {
+  constexpr int kMaxBlockDim = 512;
+  int block_dim = feature_size >= kMaxBlockDim
+                      ? kMaxBlockDim
+                      : (1 << static_cast<int>(std::log2(feature_size)));
+
+#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                \
+  case BlockDim:                                                              \
+    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(     \
+        logits_data, loss_data, feature_size);                                \
+    RowReductionForDiffMaxSum<T,                                              \
+                              BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
+        logits_data, loss_data, softmax_data, feature_size);                  \
+    RowReductionForSoftmaxAndCrossEntropy<                                    \
+        T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                    \
+        logits_data, labels_data, loss_data, softmax_data, feature_size);     \
+    break
+
+  switch (block_dim) {
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
+    case 1:
+      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
+                                                kMaxBlockDim,
+                                            kMaxBlockDim, 0, stream>>>(
+          softmax_data, batch_size);
+      cudaMemsetAsync(loss_data, 0, batch_size, stream);
+      break;
+    default:
+      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
+      break;
+  }
+
+#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
+}
+
 template <typename T>
 class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     Tensor* softmax = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
 
-    softmax->mutable_data<T>(context.GetPlace());
-    loss->mutable_data<T>(context.GetPlace());
-
-    math::SoftmaxFunctor<platform::CUDADeviceContext, T>()(
-        context.cuda_device_context(), logits, softmax);
-    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-        context.cuda_device_context(), loss, softmax, labels,
-        context.Attr<bool>("soft_label"));
+    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
+    auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+
+    auto soft_label = context.Attr<bool>("soft_label");
+    if (soft_label) {
+      int batch_size = logits->dims()[0];
+      int feature_size = logits->dims()[1];
+      auto* logits_data = logits->data<T>();
+      auto* labels_data = labels->data<T>();
+      SoftmaxWithCrossEntropyFusedKernel(
+          logits_data, labels_data, softmax_data, loss_data, batch_size,
+          feature_size, context.cuda_device_context().stream());
+    } else {
+      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
+                                     softmax);
+      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
+          context.cuda_device_context(), loss, softmax, labels, false);
+    }
   }
 };
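[Note] As a sanity check on the three-step scheme described in the comment above, the same arithmetic on the host in plain C++ (a standalone sketch, not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      // One row of logits x and soft labels y.
      std::vector<double> x = {1.0, 2.0, 3.0}, y = {0.0, 0.0, 1.0};

      // Step 1: max_i = max_j x_ij
      double max_i = *std::max_element(x.begin(), x.end());

      // Step 2: logDiffMaxSum_i = log(sum_j exp(x_ij - max_i))
      double sum = 0;
      for (double v : x) sum += std::exp(v - max_i);
      double log_diff_max_sum = std::log(sum);

      // Step 3: tmp_ij = x_ij - max_i - logDiffMaxSum_i,
      // softmax_ij = exp(tmp_ij), loss_i = -sum_j y_ij * tmp_ij
      double loss = 0;
      for (size_t j = 0; j < x.size(); ++j) {
        double tmp = x[j] - max_i - log_diff_max_sum;
        printf("softmax[%zu] = %f\n", j, std::exp(tmp));
        loss -= y[j] * tmp;
      }
      printf("loss = %f\n", loss);  // == -log(softmax[2]) ~= 0.4076
      return 0;
    }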
diff --git a/python/paddle/dataset/conll05.py b/python/paddle/dataset/conll05.py
index 4e94ce89892f8e6822c15fdc510805e75dfca988..a7c3c5402ec305797293adcc2ec00b78d8f83e95 100644
--- a/python/paddle/dataset/conll05.py
+++ b/python/paddle/dataset/conll05.py
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
 
 DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
 DATA_MD5 = '387719152ae52d60422c016e92a742fc'
-WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt'
+WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
 WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
-VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt'
+VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
 VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
-TRGDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt'
+TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
 TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
-EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb'
+EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
 EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
 
 UNK_IDX = 0
diff --git a/python/paddle/dataset/wmt14.py b/python/paddle/dataset/wmt14.py
index f0908c737874fa7335cca5b5f0cba83190c9f90f..7a157e3497eb16dbf004469dcd906c97e41f3378 100644
--- a/python/paddle/dataset/wmt14.py
+++ b/python/paddle/dataset/wmt14.py
@@ -40,7 +40,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
              'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
-URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
+URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt%2Fwmt14.tgz'
 MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
 
 START = "<s>"
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index b24036326d51aa56220d46cba202a0d4b93cdd7c..abd372126848c5779cf7d989dc03e421dc94b1cf 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
         self.origin_prog = main.clone()
         return main
 
-    def get_trainer(self, config=None):
-        t = self._transpiler_instance(config)
+    def get_trainer(self, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         return t.get_trainer_program()
 
-    def get_pserver(self, ep, config=None):
-        t = self._transpiler_instance(config)
+    def get_pserver(self, ep, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         pserver = t.get_pserver_program(ep)
         startup = t.get_startup_program(ep, pserver)
         return pserver, startup
 
-    def _transpiler_instance(self, config=None):
+    def _transpiler_instance(self, config=None, sync_mode=True):
         if not self.transpiler:
             main = self.get_main_program()
             self.transpiler = fluid.DistributeTranspiler(config=config)
@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
                 self.trainer_id,
                 program=main,
                 pservers=self.pserver_eps,
-                trainers=self.trainers)
+                trainers=self.trainers,
+                sync_mode=sync_mode)
         return self.transpiler
 
 
@@ -464,5 +465,76 @@ class TestDistLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestAsyncLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["adam", "scale", "scale"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
+            'recv', 'recv', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
+class TestAsyncDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'recv', 'recv'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index b0a100e1db34ad2971eadabff09fa5d0ce3f51dc..820509bbcc4679cadb06554476798c76e6869eb5 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -293,14 +293,15 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        program.global_block().append_op(
-            type="fetch_barrier",
-            inputs={},
-            outputs={},
-            attrs={
-                "endpoints": pserver_endpoints,
-                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
-            })
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="fetch_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": pserver_endpoints,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
 
         for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1:
diff --git a/python/paddle/v2/dataset/conll05.py b/python/paddle/v2/dataset/conll05.py
index 0d544efac9cd20157f87b5cd3b68f97ab5ed2dbc..8312900dc43fdd64cc1a205ab846b6f1deaecf5d 100644
--- a/python/paddle/v2/dataset/conll05.py
+++ b/python/paddle/v2/dataset/conll05.py
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
 
 DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
 DATA_MD5 = '387719152ae52d60422c016e92a742fc'
-WORDDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/wordDict.txt'
+WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
 WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
-VERBDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/verbDict.txt'
+VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
 VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
-TRGDICT_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/targetDict.txt'
+TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
 TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
-EMB_URL = 'http://paddlepaddle.bj.bcebos.com/demo/srl_dict_and_embedding/emb'
+EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
 EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
 
 UNK_IDX = 0
diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py
index 5104e29051e4480f3a7eb18421f1b519841b009b..1ec210f265049c8b62cd99cd218f25a9f846ef43 100644
--- a/python/paddle/v2/dataset/wmt14.py
+++ b/python/paddle/v2/dataset/wmt14.py
@@ -41,7 +41,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
              'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
-URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
+URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt%2Fwmt14.tgz'
 MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
 
 START = "<s>"