Commit 08295f98 authored by yuyang18

Add build strategy

Parent c06b4483
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace paddle {
namespace framework {
namespace details {

struct BuildStrategy {
  enum class ReduceStrategy { kAllReduce = 0, kReduce = 1 };

  enum class GradientScaleStrategy {
    kCoeffNumDevice = 0,
    kOne = 1,
    kCustomized = 2,
  };

  ReduceStrategy reduce_{ReduceStrategy::kReduce};
  GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
};
} // namespace details
} // namespace framework
} // namespace paddle
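
For orientation, a minimal usage sketch of this strategy from the Python API, assuming the fluid.BuildStrategy binding and the build_strategy constructor argument introduced later in this commit (the `loss` variable is the caller's):

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Reduce: each parameter's gradient is reduced on one device (round-robin);
# AllReduce: every gradient is all-reduced across all devices.
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

exe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=loss.name,  # `loss` is assumed to be defined by the caller
    build_strategy=build_strategy)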
@@ -37,31 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
-    bool balance_parameter_opt_between_cards)
+    platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
       nccl_ctxs_(nccl_ctxs),
-      balance_parameter_opt_between_cards_(
-          balance_parameter_opt_between_cards) {
+      strategy_(strategy) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
-    bool balance_parameter_opt_between_cards)
+    const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
-      balance_parameter_opt_between_cards_(
-          balance_parameter_opt_between_cards) {
+      strategy_(strategy) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
-  use_default_grad_scale_ = use_default_grad_scale;
 }

@@ -146,7 +141,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
-      if (use_default_grad_scale_) {
+      if (strategy_.gradient_scale_ !=
+          BuildStrategy::GradientScaleStrategy::kCustomized) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;

@@ -165,19 +161,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // broadcast, and each gradient is only broadcast once.
       for (auto &og : op->OutputArgumentNames()) {
         if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
-          if (balance_parameter_opt_between_cards_) {
+          switch (strategy_.reduce_) {
+            case BuildStrategy::ReduceStrategy::kReduce:
               CreateReduceOp(&result, og, cur_device_id);
               var_name_on_devices[cur_device_id].emplace(og);
               bcast_var_name_set[cur_device_id].emplace(
                   og.substr(0, og.size() - strlen(kGradVarSuffix)));
               cur_device_id = (cur_device_id + 1) % places_.size();
-          } else {
+              break;
+            case BuildStrategy::ReduceStrategy::kAllReduce:
               if (IsSparseGradient(var_types, og)) {
                 CreateReduceOp(&result, og, 0);
                 CreateBroadcastOp(&result, og, 0);
               } else {
                 InsertNCCLAllReduceOp(&result, og);
               }
+              break;
           }
         }
       }

@@ -303,7 +302,7 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
 int MultiDevSSAGraphBuilder::GetOpDeviceID(
     const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
     const OpDesc &op) const {
-  if (!balance_parameter_opt_between_cards_) {
+  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
......
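
As an illustration only (not Paddle code), the round-robin device assignment performed by the kReduce branch above can be modeled as:

def assign_reduce_devices(grad_names, num_devices):
    """Mimic kReduce: each parameter gradient is reduced on a single device, round-robin."""
    assignment = {}
    cur_device_id = 0
    for og in grad_names:
        assignment[og] = cur_device_id
        cur_device_id = (cur_device_id + 1) % num_devices
    return assignment

# On two devices: {'fc_w@GRAD': 0, 'fc_b@GRAD': 1, 'emb@GRAD': 0}
print(assign_reduce_devices(['fc_w@GRAD', 'fc_b@GRAD', 'emb@GRAD'], 2))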
@@ -17,6 +17,7 @@
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/details/build_strategy.h"
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"

 namespace paddle {

@@ -36,15 +37,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
                           platform::NCCLContextMap *nccl_ctxs,
-                          bool use_default_grad_scale,
-                          bool balance_parameter_opt_between_cards);
+                          const BuildStrategy &strategy);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          bool use_default_grad_scale,
-                          bool balance_parameter_opt_between_cards);
+                          const BuildStrategy &strategy);
 #endif

   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;

@@ -62,8 +61,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
 #endif
-  bool balance_parameter_opt_between_cards_;
-  bool use_default_grad_scale_;

   bool IsScaleLossOp(const OpDesc &op) const;

@@ -105,6 +102,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsSparseGradient(
       const std::unordered_map<std::string, proto::VarType::Type> &var_types,
       const std::string &og) const;
+
+ private:
+  BuildStrategy strategy_;
 };

 } // namespace details
 } // namespace framework
......
@@ -57,8 +57,7 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &bcast_vars,
     const ProgramDesc &main_program, const std::string &loss_var_name,
     Scope *scope, const std::vector<Scope *> &local_scopes,
-    bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
-    const ExecutionStrategy &exec_strategy)
+    const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;

@@ -93,12 +92,11 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      member_->nccl_ctxs_.get(), use_default_grad_scale,
-      balance_parameter_opt_between_cards);
+      member_->nccl_ctxs_.get(), build_strategy);
 #else
-  details::MultiDevSSAGraphBuilder builder(
-      member_->places_, loss_var_name, params, member_->local_scopes_,
-      use_default_grad_scale, balance_parameter_opt_between_cards);
+  details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
+                                           params, member_->local_scopes_,
+                                           build_strategy);
 #endif
   auto graph = builder.Build(main_program);
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <paddle/fluid/framework/details/build_strategy.h>
 #include <string>
 #include <unordered_set>
 #include <vector>

@@ -29,6 +30,7 @@ namespace framework {
 class ParallelExecutorPrivate;

+using details::BuildStrategy;
 using details::ExecutionStrategy;

 class ParallelExecutor {

@@ -41,9 +43,8 @@ class ParallelExecutor {
                             const ProgramDesc &main_program,
                             const std::string &loss_var_name, Scope *scope,
                             const std::vector<Scope *> &local_scopes,
-                            bool use_default_grad_scale,
-                            bool balance_parameter_opt_between_cards,
-                            const ExecutionStrategy &exec_strategy);
+                            const ExecutionStrategy &exec_strategy,
+                            const BuildStrategy &build_strategy);

   ~ParallelExecutor();
......
@@ -50,7 +50,7 @@ class NCCLGroupGuard {
   }

   inline ~NCCLGroupGuard() {
-    PADDLE_ENFORCE(dynload::ncclGroupEnd());
+    CHECK_EQ(dynload::ncclGroupEnd(), ncclSuccess);
     NCCLMutex().unlock();
   }
 };
......
@@ -494,6 +494,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("disable_profiler", platform::DisableProfiler);
   m.def("reset_profiler", platform::ResetProfiler);

+  // -- python binds for parallel executor.
   py::class_<ParallelExecutor> pe(m, "ParallelExecutor");

   py::class_<ExecutionStrategy>(pe, "ExecutionStrategy")
       .def(py::init())

@@ -515,12 +516,38 @@ All parameter, weight, gradient are variables in Paddle.
                     [](ExecutionStrategy &self, bool allow_op_delay) {
                       self.allow_op_delay_ = allow_op_delay;
                     });
+
+  py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
+
+  py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
+      .value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
+      .value("AllReduce", BuildStrategy::ReduceStrategy::kAllReduce);
+  py::enum_<BuildStrategy::GradientScaleStrategy>(build_strategy,
+                                                  "GradientScaleStrategy")
+      .value("CoeffNumDevice",
+             BuildStrategy::GradientScaleStrategy::kCoeffNumDevice)
+      .value("One", BuildStrategy::GradientScaleStrategy::kOne)
+      .value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized);
+
+  build_strategy.def(py::init())
+      .def_property(
+          "reduce_strategy",
+          [](const BuildStrategy &self) { return self.reduce_; },
+          [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
+            self.reduce_ = strategy;
+          })
+      .def_property(
+          "gradient_scale_strategy",
+          [](const BuildStrategy &self) { return self.gradient_scale_; },
+          [](BuildStrategy &self,
+             BuildStrategy::GradientScaleStrategy strategy) {
+            self.gradient_scale_ = strategy;
+          });
+
   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &,
                   const std::unordered_set<std::string> &, const ProgramDesc &,
-                  const std::string &, Scope *, std::vector<Scope *> &, bool,
-                  bool, const ExecutionStrategy &>())
+                  const std::string &, Scope *, std::vector<Scope *> &,
+                  const ExecutionStrategy &, const BuildStrategy &>())
       .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
......
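
A hedged sketch of what the bindings above expose through paddle.fluid.core; the comments on the gradient-scale values are an interpretation of the enum names and of the builder skipping CreateScaleLossGradOp only for Customized:

from paddle.fluid import core

BuildStrategy = core.ParallelExecutor.BuildStrategy  # nested class bound above
strategy = BuildStrategy()
# Customized: no scale op is inserted; the user is expected to provide loss@grad.
# CoeffNumDevice / One: the builder still inserts the scale op for loss@grad
# (presumably 1/num_devices vs. a constant 1, going by the names).
strategy.gradient_scale_strategy = BuildStrategy.GradientScaleStrategy.One
strategy.reduce_strategy = BuildStrategy.ReduceStrategy.AllReduce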
@@ -52,12 +52,14 @@ import clip
 import profiler
 import unique_name
 import recordio_writer
-from parallel_executor import ParallelExecutor, ExecutionStrategy
+import parallel_executor
+from parallel_executor import *

 Tensor = LoDTensor

 __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
-    trainer.__all__ + inferencer.__all__ + transpiler.__all__ + [
+    trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
+    parallel_executor.__all__ + [
     'io',
     'initializer',
     'layers',

@@ -79,8 +81,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
     'profiler',
     'unique_name',
     'recordio_writer',
-    'ParallelExecutor',
-    'ExecutionStrategy',
 ]
......
@@ -19,9 +19,10 @@ import executor
 import warnings
 import sys

-__all__ = ['ParallelExecutor', 'ExecutionStrategy']
+__all__ = ['ParallelExecutor', 'ExecutionStrategy', 'BuildStrategy']

 ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
+BuildStrategy = core.ParallelExecutor.BuildStrategy


 class ParallelExecutor(object):

@@ -30,9 +31,8 @@ class ParallelExecutor(object):
                  loss_name=None,
                  main_program=None,
                  share_vars_from=None,
-                 use_default_grad_scale=True,
-                 balance_parameter_opt_between_cards=False,
                  exec_strategy=None,
+                 build_strategy=None,
                  **kwargs):
         """
         ParallelExecutor can run program in parallel.

@@ -81,7 +81,16 @@ class ParallelExecutor(object):
                         "Setting {0} by constructor is deprecated. Use " \
                         "strategy=ExecutionStrategy(); strategy.{0}=xxx; " \
                         "pe=ParallelExecutor(exec_strategy=strategy) " \
-                        "instead.\n "
+                        "instead.\n ".format(key)
+                elif key in dir(BuildStrategy):
+                    err_msg += \
+                        "Setting {0} by constructor is deprecated. Use " \
+                        "strategy=BuildStrategy(); See help(" \
+                        "paddle.fluid.ParallelExecutor.BuildStrategy) \n".format(
+                            key)
+                else:
+                    err_msg += "Setting {0} by constructor is deprecated. Use strategy.\n".format(
+                        key)
             raise ValueError(err_msg)

         self._places = []

@@ -116,6 +125,9 @@ class ParallelExecutor(object):
                 exec_strategy.num_threads = min(
                     len(self._places) * 2, multiprocessing.cpu_count())

+        if build_strategy is None:
+            build_strategy = BuildStrategy()
+
         main = main_program
         main = main if main else framework.default_main_program()
         scope = executor.global_scope()

@@ -139,9 +151,8 @@ class ParallelExecutor(object):
                 p.name for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(self.persistable_vars), main.desc, loss_name
-            if loss_name else '', scope, local_scopes, use_default_grad_scale,
-            balance_parameter_opt_between_cards, exec_strategy)
+            set(self.persistable_vars), main.desc, loss_name if loss_name else
+            '', scope, local_scopes, exec_strategy, build_strategy)
         self.scope = scope
......
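
For illustration, a hedged example of the reworked kwargs check above: passing an old-style constructor keyword now raises ValueError whose message points at the matching strategy class (allow_op_delay is used here on the assumption that it is an ExecutionStrategy field):

import paddle.fluid as fluid

try:
    # allow_op_delay is no longer a constructor argument, so it lands in **kwargs
    fluid.ParallelExecutor(use_cuda=True, allow_op_delay=True)
except ValueError as err:
    print(err)  # suggests ExecutionStrategy; the same check points BuildStrategy fields at BuildStrategy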
@@ -234,12 +234,16 @@ class TestParallelExecutorBase(unittest.TestCase):
         startup_exe.run(startup)
         exec_strategy = fluid.ExecutionStrategy()
         exec_strategy.allow_op_delay = allow_op_delay

+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce if balance_parameter_opt_between_cards else fluid.BuildStrategy.ReduceStrategy.AllReduce
+
         if use_parallel_executor:
             exe = fluid.ParallelExecutor(
                 True,
                 loss_name=loss.name,
-                balance_parameter_opt_between_cards=balance_parameter_opt_between_cards,
-                exec_strategy=exec_strategy)
+                exec_strategy=exec_strategy,
+                build_strategy=build_strategy)
         else:
             exe = fluid.Executor(place=place)

@@ -548,7 +552,7 @@ class TestTransformer(TestParallelExecutorBase):
 class ParallelExecutorTestingDuringTraining(unittest.TestCase):
-    def check_network_convergence(self, balance_parameter_opt_between_cards):
+    def check_network_convergence(self, build_strategy=None):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):

@@ -571,15 +575,13 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
                 use_cuda=True,
                 loss_name=loss.name,
                 main_program=main,
-                balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
-            )
+                build_strategy=build_strategy)

             test_exe = fluid.ParallelExecutor(
                 use_cuda=True,
                 main_program=test_program,
                 share_vars_from=train_exe,
-                balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
-            )
+                build_strategy=build_strategy)

             for i in xrange(5):
                 test_loss, = test_exe.run([loss.name], feed=feed_dict)

@@ -594,10 +596,14 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
                 str(test_loss))

     def test_parallel_testing(self):
-        self.check_network_convergence(False)
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+        self.check_network_convergence(build_strategy)

     def test_parallel_testing_with_new_strategy(self):
-        self.check_network_convergence(True)
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
+        self.check_network_convergence(build_strategy)


 import paddle.dataset.conll05 as conll05

@@ -617,7 +623,7 @@ embedding_name = 'emb'
 def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
-            is_sparse, balance_parameter_opt_between_cards, **ignored):
+            is_sparse, **ignored):
     # 8 features
     predicate_embedding = fluid.layers.embedding(
         input=predicate,

@@ -686,9 +692,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
 class TestCRFModel(unittest.TestCase):
-    def check_network_convergence(self,
-                                  is_sparse,
-                                  balance_parameter_opt_between_cards=False):
+    def check_network_convergence(self, is_sparse, build_strategy=None):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):

@@ -739,8 +743,7 @@ class TestCRFModel(unittest.TestCase):
             pe = fluid.ParallelExecutor(
                 use_cuda=True,
                 loss_name=avg_cost.name,
-                balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
-            )
+                build_strategy=build_strategy)

             feeder = fluid.DataFeeder(
                 feed_list=[

@@ -756,19 +759,29 @@ class TestCRFModel(unittest.TestCase):
                     pe.run(feed=feeder.feed(cur_batch),
                            fetch_list=[avg_cost.name]))[0]

-    def test_update_sparse_parameter(self):
-        self.check_network_convergence(is_sparse=True)
+    def test_update_sparse_parameter_all_reduce(self):
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+        self.check_network_convergence(
+            is_sparse=True, build_strategy=build_strategy)

-    def test_update_dense_parameter(self):
-        self.check_network_convergence(is_sparse=False)
+    def test_update_dense_parameter_all_reduce(self):
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
+        self.check_network_convergence(
+            is_sparse=False, build_strategy=build_strategy)

-    def test_update_sparse_parameter_with_new_strategy(self):
+    def test_update_sparse_parameter_reduce(self):
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         self.check_network_convergence(
-            is_sparse=False, balance_parameter_opt_between_cards=True)
+            is_sparse=False, build_strategy=build_strategy)

-    def test_update_dense_parameter_with_new_strategy(self):
+    def test_update_dense_parameter_reduce(self):
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         self.check_network_convergence(
-            is_sparse=False, balance_parameter_opt_between_cards=True)
+            is_sparse=False, build_strategy=build_strategy)

     # test fetch all the variables of global_block
......