diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index ed75b48090b27a9c430afe067467a1a39d711938..61276efedeeca76a8818c15ddab73b3c53725c4b 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -53,6 +53,10 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
       this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
     }
   }
+  // TODO(gongwb): polish them!
+  if (is_encoded) {
+    VLOG(1) << "Use dgc allreduce mode";
+  }
 }
 #else
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
@@ -86,7 +90,7 @@ void AllReduceOpHandle::RunImplEncoded() {
         paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
     auto encode_var_name = original_name + g_dgc_encoded;
     auto *in_var = local_scope->FindVar(encode_var_name);
-    PADDLE_ENFORCE_NOT_NULL(in_var);
+    PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
     auto &in = in_var->Get<LoDTensor>();
     ins.emplace_back(&in);
 
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index f8bf43bcb48226b4d1317a78ade7179741097378..afe5078bf80d00b595789a5f45d91a5e7a8dfce6 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -142,6 +142,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("memory_optimize_pass");
     }
 
+    // runtime_context_cache pass should be the last pass to enable the attr of
+    // all original and fused operators. However, no operator can have this
+    // attr enabled if the pass is placed after MultiDevPass.
+    if (strategy_.cache_runtime_context_) {
+      VLOG(10) << "Add runtime_context_cache_pass";
+      AppendPass("runtime_context_cache_pass");
+    }
+
     AppendMultiDevPass(strategy_);
 
     if (strategy_.fuse_all_reduce_ops_) {
@@ -243,7 +251,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
   CreatePassesFromStrategy(false);
 
   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
-    VLOG(3) << "apply " << pass->Type();
+    VLOG(3) << "BuildStrategy::Apply pass:" << pass->Type();
     if (IsMultiDevPass(pass->Type())) {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
@@ -328,3 +336,4 @@ USE_PASS(graph_to_program_pass);
 USE_PASS(fuse_adam_op_pass);
 USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
+USE_PASS(runtime_context_cache_pass);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index cc48c51e924039d93b2e1e18bea752611e7bef92..8aa444a30c0f7f1f5c19d54cf248f86c3e3b3cf3 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -107,6 +107,8 @@ struct BuildStrategy {
   std::vector<std::string> trainers_endpoints_;
   bool remove_unnecessary_lock_{true};
 
+  bool cache_runtime_context_{false};
+
   // NOTE:
   // Before you add new options, think if it's a general strategy that works
   // with other strategy. If not, the strategy should be created through
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ba1d7379c56d953a0f37d03deed6c47e46cbf129..16fc1721eb6f5d2517ad45289f2415ef41749df2 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -68,6 +68,7 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
 pass_library(identity_scale_op_clean_pass base)
 pass_library(sync_batch_norm_pass base)
 pass_library(runtime_context_cache_pass base)
+pass_library(expected_kernel_cache_pass base)
 pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(fillconstant_elementwisemul_fuse inference)
 
diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ee67af0aff5c90a9da0ece8f197d9a0c0a8e5b9c
--- /dev/null
+++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc
@@ -0,0 +1,37 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/expected_kernel_cache_pass.h"
+#include <memory>
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void ExpectedKernelCachePass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(3) << "Applies Expected Kernel Cache strategy.";
+  for (const Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      n->Op()->SetAttr(kEnableCacheExpectedKernel, true);
+    }
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(expected_kernel_cache_pass,
+              paddle::framework::ir::ExpectedKernelCachePass);
diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.h b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf0907d3feb7bccd163363da65505e0af3fb9bf6
--- /dev/null
+++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class ExpectedKernelCachePass : public Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
index c7cf9b0dc342bbfaa80b622d7dcd0f6348f78d42..566b654f237cbd71e1983c971374ee13d7b36805 100644
--- a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
+++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
@@ -23,7 +23,7 @@ namespace ir {
 void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies Runtime Context Cache strategy.";
   for (const Node* n : graph->Nodes()) {
-    if (n->IsOp()) {
+    if (n->IsOp() && n->Op()) {
       n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
     }
   }
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 8dd5bc6dd8ed49b242d23cf7747288e0d86e5c49..040f669f2b4da50d005dd52ee7cc49f9424eff20 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -906,50 +906,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
 
-  // check if op[type] has kernel registered.
-  auto& all_op_kernels = AllOpKernels();
-  auto kernels_iter = all_op_kernels.find(type_);
-  if (kernels_iter == all_op_kernels.end()) {
-    PADDLE_THROW(
-        "There are no kernels which are registered in the %s operator.", type_);
+  if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) {
+    ChooseKernel(*runtime_ctx, scope, place);
   }
 
-  OpKernelMap& kernels = kernels_iter->second;
-
-  auto expected_kernel_key = this->GetExpectedKernelType(
-      ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
-  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
-
-  auto kernel_iter = kernels.find(expected_kernel_key);
-#ifdef PADDLE_WITH_MKLDNN
-  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
-  if (kernel_iter == kernels.end() &&
-      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
-    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
-    expected_kernel_key.library_type_ = LibraryType::kPlain;
-    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
-    kernel_iter = kernels.find(expected_kernel_key);
-  }
-#endif
-  if (kernel_iter == kernels.end()) {
-    PADDLE_THROW("op %s does not have kernel for %s", type_,
-                 KernelTypeToString(expected_kernel_key));
-  }
-
-  std::vector<KernelConfig>* kernel_configs =
-      GetKernelConfig(expected_kernel_key);
+  std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
 
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
-  auto* transfer_scope = PrepareData(scope, expected_kernel_key,
-                                     &transfered_inplace_vars, runtime_ctx);
+  auto* transfer_scope =
+      PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
 
   // exec scope is the scope that kernel actually executed on.
   const Scope& exec_scope =
       (transfer_scope == nullptr ? scope : *transfer_scope);
 
-  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
-    dev_ctx = pool.Get(expected_kernel_key.place_);
+  if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
+    dev_ctx = pool.Get(kernel_type_->place_);
   }
 
   if (!all_kernels_must_compute_runtime_shape) {
@@ -958,8 +931,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
-  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
-                                       *runtime_ctx, kernel_configs));
+  (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
+                                   kernel_configs));
 
   if (!transfered_inplace_vars.empty()) {
     // there is inplace variable has been transfered.
@@ -985,6 +958,46 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 }
 
+void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
+                                      const Scope& scope,
+                                      const platform::Place& place) const {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
+
+  // check if op[type] has kernel registered.
+  auto& all_op_kernels = AllOpKernels();
+  auto kernels_iter = all_op_kernels.find(type_);
+  if (kernels_iter == all_op_kernels.end()) {
+    PADDLE_THROW(
+        "There are no kernels which are registered in the %s operator.", type_);
+  }
+
+  OpKernelMap& kernels = kernels_iter->second;
+
+  auto expected_kernel_key = this->GetExpectedKernelType(
+      ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
+  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
+  auto kernel_iter = kernels.find(expected_kernel_key);
+#ifdef PADDLE_WITH_MKLDNN
+  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+  if (kernel_iter == kernels.end() &&
+      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
+    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    expected_kernel_key.library_type_ = LibraryType::kPlain;
+    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
+    kernel_iter = kernels.find(expected_kernel_key);
+  }
+#endif
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("op %s does not have kernel for %s", type_,
+                 KernelTypeToString(expected_kernel_key));
+  }
+
+  kernel_type_.reset(new OpKernelType(expected_kernel_key));
+  kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
+}
+
 void OperatorWithKernel::TransferInplaceVarsBack(
     const Scope& scope, const std::vector<std::string>& inplace_vars,
     const Scope& transfer_scope) const {
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 108094d81ba429dca1e7a75fcbce43733ab125dd..2e733badb8ab17826f7a12e9e17771f3b8dd3fb8 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -70,6 +70,12 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
 /// this Op's execution to save the elapsed time.
 constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
 
+/// If an Op has attribute kEnableCacheExpectedKernel, it means that in the
+/// same name scope and same place, since the expected kernel of this Op does
+/// not change during execution, it can be recorded only at the first
+/// iteration of this Op's execution to save the elapsed time.
+constexpr char kEnableCacheExpectedKernel[] = "@ENABLE_CACHE_EXPECTED_KERNEL@";
+
 /// If an Op has this attribute, all its kernels should calculate output
 /// variable's shape in the corresponding Compute() function. And
 /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
@@ -491,8 +497,13 @@ class OperatorWithKernel : public OperatorBase {
                                const std::vector<std::string>& inplace_vars,
                                const Scope& exec_scope) const;
 
+  void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
+                    const platform::Place& place) const;
+
  protected:
   mutable OpKernelConfigsMap kernel_configs_map_;
+  mutable std::unique_ptr<OpKernelType> kernel_type_;
+  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
   mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
   mutable const Scope* pre_scope_ = nullptr;
   mutable bool enable_cache_runtime_context = false;
diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc
index 04364e68342810f6babee1df0c5eb3b476de022a..f96c83936df590e5bd3abe89b7e7c2a6ddf92d01 100644
--- a/paddle/fluid/imperative/nccl_context.cc
+++ b/paddle/fluid/imperative/nccl_context.cc
@@ -33,8 +33,7 @@ void NCCLParallelContext::RecvNCCLID(const std::string &ep,
   // creating socket fd
   if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0)
     PADDLE_THROW("create server fd failed");
-  if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt,
-                 sizeof(opt)))
+  if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)))
     PADDLE_THROW("set socket opt failed");
 
   address.sin_family = AF_INET;
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 0109b4a4fa7617880f642f5a39639bca38475515..b54ea269ff250f02b6331807237e10ee65b0b0b4 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -231,6 +231,7 @@ void AnalysisConfig::Update() {
       pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
     }
     pass_builder()->DeletePass("runtime_context_cache_pass");
+    pass_builder()->DeletePass("expected_kernel_cache_pass");
   }
 
   if (use_mkldnn_) {
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1a87ac378a232f153a994c6b11ff37f8f5419a36..de564dbb40b3fed8cb165e34877a8cdc3ee5e349 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -99,6 +99,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",
+        "expected_kernel_cache_pass",          //
   });
 
   use_gpu_ = true;
@@ -136,6 +137,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
                 "conv_bn_fuse_pass",              //
                 "conv_eltwiseadd_bn_fuse_pass",   //
                 "is_test_pass",                   //
+                "expected_kernel_cache_pass",     //
   });
 
   use_gpu_ = false;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 8c34e3efe2a07cadde5aa06669fda88be7661db1..0d1a1ca07771b991df78743501907401cf56f933 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1356,6 +1356,10 @@ All parameter, weight, gradient are variables in Paddle.
"fuse_all_reduce_ops", [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; }) + .def_property( + "cache_runtime_context", + [](const BuildStrategy &self) { return self.cache_runtime_context_; }, + [](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true); diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 7e6e37116fe23f26eb14dd0573dbe031aec98dd8..94bc3d0854d5b17de4811aca07c35fbaa0e382ca 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -752,7 +752,7 @@ class DGCMomentumOptimizer(MomentumOptimizer): force_cpu=True) for param_var, grad_var in param_and_grads: - var_numel = reduce(lambda x, y: x * y, param_var.shape) + var_numel = abs(reduce(lambda x, y: x * y, param_var.shape)) if var_numel < 16384 or \ param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \ grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \ diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 6b88e7a99fd78f6a7670ba55bc678e85d229ddf4..092cd5aea7d2f3ae7e5ba927261921fbe28f51bf 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -104,6 +104,7 @@ class ParallelExecutor(object): self._scope = scope if scope is not None else executor.global_scope() if main_program is not None and main_program._enable_dgc: + assert num_trainers > 1 assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce assert num_trainers * len( self._places) > 1, "dgc is not useful for single card training" @@ -123,6 +124,11 @@ class ParallelExecutor(object): exec_strategy=exec_strategy, share_vars_from=share_vars_from._compiled_program if share_vars_from else None) + + # FIXME(gongwb): I will move dgc from dist mode to allreduce mode in next pr. + if main_program._enable_dgc: + self._compiled_program._build_strategy.is_distribution = True + self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() self._exe = executor.Executor(self._place) self._compiled_program._compile(place=self._place, scope=self._scope)