From 71168dad57402e3c15531cfb5342e94f4a456cf6 Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Wed, 21 Aug 2019 10:03:05 +0800 Subject: [PATCH] [Cherry Pick] Bug fix and speedup dygraph multi-cards on v1.5 (#19298) * add warning info for CPU_NUM test=develop * update dygraph parallel.py test=develop * prune the feed op in compiler test=release/1.5 * remove compile from PE test=develop * test CUDAPinnedPlace in reader test=release/1.5 --- paddle/fluid/framework/tensor_util.cc | 4 + paddle/fluid/operators/assign_op.cc | 62 +++++++++----- .../fluid/operators/reader/buffered_reader.cc | 13 ++- paddle/fluid/pybind/pybind.cc | 4 +- python/paddle/fluid/compiler.py | 12 ++- python/paddle/fluid/dygraph/parallel.py | 68 ++++++++++++++- python/paddle/fluid/framework.py | 13 +-- python/paddle/fluid/parallel_executor.py | 2 +- ...arallel_executor_run_load_infer_program.py | 85 +++++++++++++++++++ 9 files changed, 226 insertions(+), 37 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 33ef3b9186..3bc24bc1e5 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -99,6 +99,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_THROW("ctx is not belong to dst_gpu_place or src_gpu_place."); } } + } else { + PADDLE_THROW("Copy from %s to %s is not supported.", src_place, dst_place); } #endif } @@ -166,6 +168,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, auto dst_gpu_place = boost::get(dst_place); memory::Copy(dst_gpu_place, dst_ptr, src_pinned_place, src_ptr, size, nullptr); + } else { + PADDLE_THROW("Copy from %s to %s is not supported.", src_place, dst_place); } #endif } diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index d9294048a9..871dfe6734 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -66,27 +66,47 @@ class AssignFunctor { const platform::DeviceContext &dev_ctx_; }; -class AssignOp : public framework::OperatorBase { +class AssignOp : public framework::OperatorWithKernel { public: AssignOp(const std::string &type, const framework::VariableNameMap &inputs, const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + : OperatorWithKernel(type, inputs, outputs, attrs) {} - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto *x = scope.FindVar(Input("X")); + void InferShape(framework::InferShapeContext *ctx) const override { + if (ctx->HasInput("X")) { + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class AssignKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.InputVar("X"); if (x == nullptr) { return; } - auto *out = scope.FindVar(Output("Out")); + auto *out = ctx.OutputVar("Out"); PADDLE_ENFORCE( out != nullptr, "The Output(Out) should not be null if the Input(X) is set."); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); + auto &dev_ctx = *pool.Get(ctx.GetPlace()); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); } @@ -110,19 +130,6 @@ raise error if the type is not listed above. } }; -class AssignInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *context) const override { - if (context->HasInput("X")) { - auto type = context->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::SELECTED_ROWS || - type == framework::proto::VarType::LOD_TENSOR) { - context->SetOutputDim("Out", context->GetInputDim("X")); - } - } - } -}; - class AssignGradMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -142,4 +149,13 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, - ops::AssignInferShape, ops::AssignOpProtoMaker); + ops::AssignOpProtoMaker); +REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, + ops::AssignKernel, int, ops::AssignKernel, + int64_t, ops::AssignKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, + ops::AssignKernel, int, ops::AssignKernel, + int64_t, ops::AssignKernel); +#endif diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index 16cb08f419..b332450c25 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -128,9 +128,18 @@ void BufferedReader::ReadAsync(size_t i) { boost::get(cpu_place), cpu_ptr, size, stream_); } else { + platform::CUDAPinnedPlace cuda_pinned_place; + framework::LoDTensor cuda_pinned_tensor; + cuda_pinned_tensor.Resize(cpu[i].dims()); + auto cuda_pinned_ptr = + cuda_pinned_tensor.mutable_data(cuda_pinned_place, cpu[i].type()); + memory::Copy(cuda_pinned_place, cuda_pinned_ptr, + boost::get(cpu_place), cpu_ptr, + size); memory::Copy(boost::get(place_), gpu_ptr, - boost::get(cpu_place), cpu_ptr, size, - stream_); + cuda_pinned_place, cuda_pinned_ptr, size, stream_); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_), + "cuda stream sync error."); } gpu[i].set_lod(cpu[i].lod()); } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 9738601bd6..2886d9786b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -64,13 +64,13 @@ limitations under the License. */ #ifndef _WIN32 #include "paddle/fluid/pybind/nccl_wrapper_py.h" #endif +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/string/to_string.h" - #ifdef PADDLE_WITH_CUDA #ifndef _WIN32 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" @@ -1118,6 +1118,8 @@ All parameter, weight, gradient are variables in Paddle. return std::shared_ptr(std::move(pass)); }); + m.def("size_of_dtype", framework::SizeOfType); + py::class_> pass(m, "Pass"); pass.def(py::init()) .def("has", &ir::Pass::Has) diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 7ffecb69a9..d47976844f 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -46,6 +46,15 @@ def _is_pserver_mode(main_program): return False +def _prune_feed_ops(program): + # prune the feed ops in the program. + pop_idx = [] + for i, op in enumerate(program.global_block().ops): + if op.type == "feed": pop_idx.append(i) + for index in pop_idx[::-1]: + program.global_block()._remove_op(index) + + class CompiledProgram(object): """ Compiles to Graph for execution. @@ -101,6 +110,7 @@ class CompiledProgram(object): # don't not create a new program here. self._program = None elif isinstance(program_or_graph, framework.Program): + _prune_feed_ops(program_or_graph) self._graph = core.Graph(program_or_graph.desc) self._program = program_or_graph else: @@ -274,8 +284,6 @@ class CompiledProgram(object): "share_vars_from is not compiled and run, so there is no " "var to share.") self._local_scopes = self._share_vars_from._executor.local_scopes() - # drop the local_exe_scopes of the previous parallel_executor - self._share_vars_from._executor.drop_local_exe_scopes() else: assert scope is not None, "" self._local_scopes = [] diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 37716cea14..c17cfc73de 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -14,7 +14,7 @@ import os import six import numpy as np - +from collections import OrderedDict from .. import core from . import layers from . import parallel_helper @@ -36,7 +36,7 @@ def prepare_context(strategy=None): strategy.current_endpoint = Env().current_endpoint if strategy.nranks < 2: return - assert framework.in_dygraph_mode() is True,\ + assert framework.in_dygraph_mode() is True, \ "dygraph.parallel.prepare_context should be used with dygrahp mode." place = framework._current_expected_place() assert place is not None, \ @@ -168,6 +168,37 @@ class DataParallel(layers.Layer): loss = loss / loss_scale return loss + def _coalesce_tensors(self, var_groups): + from ..layers import nn + coalesced_grads_and_grad_vars = [] + for group_id, grad_vars in var_groups.items(): + flattened_vars = [] + g_var_shapes = [] + for g_var in grad_vars: + g_var_shapes.append(g_var.shape) + flattened_vars.append( + nn.reshape( + x=g_var, shape=[np.prod(g_var.shape)], inplace=True)) + coalesced_grad = nn.concat(flattened_vars) + coalesced_grads_and_grad_vars.append( + [coalesced_grad, grad_vars, g_var_shapes]) + return coalesced_grads_and_grad_vars + + def _split_tensors(self, coalesced_grads_and_grad_vars): + from ..layers import nn + for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars: + grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes] + splited_vars = nn.split( + coalesced_grad, num_or_sections=grad_var_len, dim=0) + reshaped_grad_vars = [] + for g_var, g_shape in zip(splited_vars, grad_shapes): + reshaped_grad_vars.append( + nn.reshape( + x=g_var, shape=g_shape, inplace=True)) + for origin_g_var, reshaped_g_var in zip(origin_grad_vars, + reshaped_grad_vars): + nn.assign(input=reshaped_g_var, output=origin_g_var) + def apply_collective_grads(self): """ AllReduce the Parameters' gradient. @@ -175,6 +206,8 @@ class DataParallel(layers.Layer): if not self._is_data_parallel_mode(): return + grad_var_set = set() + grad_vars = [] for param in self._layers.parameters(): # NOTE(zcd): The grad_ivar maybe no generated. if param.trainable and param._ivar._grad_ivar(): @@ -183,7 +216,36 @@ class DataParallel(layers.Layer): name=param._ivar._grad_name(), stop_gradient=True, ivar=param._ivar._grad_ivar()) - collective._allreduce(g_var, g_var, sync_mode=True) + grad_vars.append(g_var) + assert g_var not in grad_var_set + grad_var_set.add(g_var) + + # FIXME(zcd): the type of the var should be LoDTensor, i.e + # the gradients should be dense, otherwise, the following + # logic should be updated. + # 128 MB as a group + mega_bytes = 128 * 1024 * 1024 + group_idx = 0 + memory_counter = 0 + grad_var_groups = OrderedDict() + dtype = grad_vars[0].dtype + for g_var in grad_vars: + # Note: the dtype of the same group should be the same. + bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype) + if memory_counter < mega_bytes and dtype == g_var.dtype: + memory_counter += bytes + else: + memory_counter = bytes + group_idx += 1 + grad_var_groups.setdefault(group_idx, []).append(g_var) + + coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups) + + for coalesced_grad, g_vars, g_shapes in coalesced_grads_and_vars: + collective._allreduce( + coalesced_grad, coalesced_grad, sync_mode=False) + + self._split_tensors(coalesced_grads_and_vars) def _is_data_parallel_mode(self): return self._strategy.nranks > 1 diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 10a53c84af..572a7c1e91 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -85,11 +85,14 @@ def _current_expected_place(): def _cpu_num(): if "CPU_NUM" not in os.environ.keys(): - sys.stderr.write( - 'The CPU_NUM is not specified, you should set CPU_NUM in ' - 'the environment variable list, i.e export CPU_NUM=1. CPU_NUM ' - 'indicates that how many CPUPlace are used in the current task.\n' - '!!! The default number of CPUPlaces is 1.\n\n') + if multiprocessing.cpu_count() > 1: + sys.stderr.write( + '!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.\n' + 'CPU_NUM indicates that how many CPUPlace are used in the current task.\n' + 'And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.\n\n' + 'export CPU_NUM={} # for example, set CPU_NUM as number of physical CPU core which is {}.\n\n' + '!!! The default number of CPU_NUM=1.\n'.format( + multiprocessing.cpu_count(), multiprocessing.cpu_count())) os.environ['CPU_NUM'] = str(1) cpu_num = os.environ.get('CPU_NUM') return int(cpu_num) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index d4a1041a4b..576ce2a95c 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -163,6 +163,7 @@ class ParallelExecutor(object): assert isinstance( share_vars_from, ParallelExecutor ), "The share_vars_from should be ParallelExecutor." + self._compiled_program.with_data_parallel( loss_name=loss_name, build_strategy=build_strategy, @@ -172,7 +173,6 @@ class ParallelExecutor(object): self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() self._exe = executor.Executor(self._place) - self._compiled_program._compile(place=self._place, scope=self._scope) def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): """ diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py new file mode 100644 index 0000000000..fc76f5d152 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py @@ -0,0 +1,85 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.fluid as fluid +from simple_nets import simple_fc_net, init_data + + +class TestMNIST(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.save_dirname = "./" + cls.model_filename = "test_parallel_executor_run_load_infer_program_model" + cls.params_filename = "test_parallel_executor_run_load_infer_program_parameter" + cls.place = fluid.CPUPlace() + cls.exe = fluid.Executor(cls.place) + img, label = init_data() + cls.batch_data = [] + for img, label in zip(img, label): + cls.batch_data.append([img, label]) + + def test_simple_fc(self): + exe_loss = self.run_with_executor() + + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model( + self.save_dirname, self.exe, self.model_filename, + self.params_filename) + + train_exe = fluid.ParallelExecutor( + use_cuda=False, main_program=inference_program) + feed_vars = [ + inference_program.global_block().var(var_name) + for var_name in ["image", "label"] + ] + feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars) + + pe_loss = train_exe.run(feed=feeder.feed(self.batch_data), + fetch_list=[fetch_targets[0].name]) + assert exe_loss == pe_loss + + def run_with_executor(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = simple_fc_net() + + feed_vars = [ + main.global_block().var(var_name) + for var_name in ["image", "label"] + ] + feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars) + + self.exe.run(startup) + + loss_data = self.exe.run(main, + feed=feeder.feed(self.batch_data), + fetch_list=[loss.name]) + + fluid.io.save_inference_model( + self.save_dirname, ["image", "label"], [loss], + self.exe, + model_filename=self.model_filename, + params_filename=self.params_filename, + main_program=main) + + return loss_data + + +if __name__ == '__main__': + unittest.main() -- GitLab