未验证 提交 17d62ab2 编写于 作者: C chengduo 提交者: GitHub

Enhance fuse optimization op pass (#19010)

* Enhance fuse optimization op pass
test=develop
上级 21440b4d
......@@ -32,19 +32,62 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
}
void FuseOptimizerOps(
ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
auto fused_adam_node =
FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);
auto fused_scale1 =
FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"),
adam_ops, graph);
auto fused_scale2 =
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
RemoveCycleDepsBetweenOpNodes(graph, fused_scale1, fused_scale2);
return fused_adam_node;
}
void FuseAdamOps(
void RemoveCycleDepsBetweenOpNodes(Graph *graph, const Node *fused_scale1,
const Node *fused_scale2) const {
std::unordered_set<Node *> not_need_ctrl_var_nodes;
std::unordered_set<Node *> fused_scale2_in_nodes;
fused_scale2_in_nodes.insert(fused_scale2->inputs.begin(),
fused_scale2->inputs.end());
for (auto &out_node : fused_scale1->outputs) {
if (fused_scale2_in_nodes.count(out_node)) {
PADDLE_ENFORCE(out_node->IsCtrlVar(),
"The dependency var only should be ctrl var.");
not_need_ctrl_var_nodes.insert(out_node);
}
}
for (auto &node : not_need_ctrl_var_nodes) {
// remove this node from the input op node.
PADDLE_ENFORCE(!node->inputs.empty());
auto op_node = node->inputs.front();
PADDLE_ENFORCE(op_node->IsOp());
op_node->outputs.erase(
remove_if(
op_node->outputs.begin(), op_node->outputs.end(),
[&node](const Node *op_out_node) { return op_out_node == node; }),
op_node->outputs.end());
// remove this node from the output op nodes.
for (auto &out_op_node : node->outputs) {
out_op_node->inputs.erase(
remove_if(
out_op_node->inputs.begin(), out_op_node->inputs.end(),
[&node](const Node *op_in_node) { return op_in_node == node; }),
out_op_node->inputs.end());
}
graph->RemoveNode(node);
}
}
ir::Node *FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
......@@ -102,13 +145,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
adam_desc.SetAttr("min_row_size_to_use_multithread",
min_row_size_to_use_multithread);
adam_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node);
return graph->CreateOpNode(&adam_desc);
}
void FuseScaleOps(const std::vector<std::string> &beta_name,
ir::Node *FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const {
......@@ -139,7 +179,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
scale_ops.emplace_back(*scale_op_iter);
}
PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
VLOG(7) << "The number of scale op is " << scale_ops.size() << ".";
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(
......@@ -175,29 +215,12 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto scale_node = graph->CreateOpNode(&scale_desc);
for (auto scale_op : scale_ops) {
// set inputs
scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(),
scale_op->inputs.end());
for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node);
}
// set outputs
scale_node->outputs.insert(scale_node->outputs.begin(),
scale_op->outputs.begin(),
scale_op->outputs.end());
for (auto &output : scale_op->outputs) {
std::replace(output->inputs.begin(), output->inputs.end(), scale_op,
scale_node);
}
}
InsertInputAndOutputForFusedOpNode(scale_ops, graph, scale_node);
// Delete scale_ops
for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op);
}
return scale_node;
}
};
} // namespace ir
......
......@@ -33,7 +33,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
// Fuse Momentum Ops
virtual void FuseOptimizerOps(
virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
......@@ -77,9 +77,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
momentum_desc.SetAttr("use_nesterov", use_nesterov);
momentum_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto momentum_node = graph->CreateOpNode(&momentum_desc);
InserInputAndOutputForOptOps(momentum_ops, momentum_node);
return graph->CreateOpNode(&momentum_desc);
}
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <set>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -59,6 +60,15 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
}
return;
}
// There should not have no-ctr-var between the op_nodes that link the op_node
// of op_nodes.
if (HasVarDepsBetweenOps(topo_nodes, opt_nodes)) {
VLOG(6) << "There are interdependent variables among these optimization "
"operators, which can not be handled well at present.";
return;
}
result.Set(details::kFusedOptType, new details::FusedOptType);
result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
if (!result.Has(details::kProgramDescs)) {
......@@ -158,14 +168,54 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
&result);
// Step 5: Fuse optimizer Ops and Scale Ops
auto *fused_opt_node =
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_nodes, &result);
InsertInputAndOutputForFusedOpNode(opt_nodes, graph, fused_opt_node);
// Step 6: Remove optimizer Ops
for (auto &opt_op : opt_nodes) {
graph->RemoveNode(opt_op);
}
}
bool FuseOptimizerOpPass::HasVarDepsBetweenOps(
const std::vector<Node *> &topo_nodes,
const std::vector<Node *> &opt_nodes) const {
std::unordered_map<Node *, std::unordered_set<Node *>> preceding_ops;
std::unordered_map<Node *, std::unordered_set<Node *>> pending_ops;
for (auto &op : topo_nodes) {
preceding_ops[op];
pending_ops[op];
for (auto &var : op->outputs) {
if (var->IsCtrlVar()) continue;
for (auto &pending_op : var->outputs) {
preceding_ops[pending_op].insert(op);
pending_ops[op].insert(pending_op);
}
}
}
std::unordered_set<Node *> opt_node_set(opt_nodes.begin(), opt_nodes.end());
auto has_var_deps = [](const std::unordered_set<Node *> &op_set1,
const std::unordered_set<Node *> &op_set2) -> bool {
std::set<Node *> intersect_ops;
set_intersection(op_set1.begin(), op_set1.end(), op_set2.begin(),
op_set2.end(),
inserter(intersect_ops, intersect_ops.begin()));
return !intersect_ops.empty();
};
for (auto opt_node : opt_node_set) {
if (has_var_deps(preceding_ops.at(opt_node), opt_node_set)) {
return true;
}
if (has_var_deps(pending_ops.at(opt_node), opt_node_set)) {
return true;
}
}
return false;
}
void FuseOptimizerOpPass::GradientsFilter(
const std::vector<size_t> &new_grad_idx, std::vector<Node *> *opt_nodes,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set)
......@@ -338,26 +388,84 @@ void FuseOptimizerOpPass::AppendAllocContinuousSpace(
op_desc->SetAttr("check_name", check_name);
}
void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
const std::vector<ir::Node *> &opt_nodes, ir::Node *opt_node) const {
void FuseOptimizerOpPass::InsertInputAndOutputForFusedOpNode(
const std::vector<ir::Node *> &op_nodes, ir::Graph *graph,
ir::Node *fused_opt_node) const {
std::unordered_set<ir::Node *> inputs;
std::unordered_set<ir::Node *> outputs;
for (auto opt_op : opt_nodes) {
// set inputs
for (auto opt_op : op_nodes) {
inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end());
for (auto &input : opt_op->inputs) {
replace(input->outputs.begin(), input->outputs.end(), opt_op, opt_node);
replace(input->outputs.begin(), input->outputs.end(), opt_op,
fused_opt_node);
}
// set outputs
outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end());
for (auto &output : opt_op->outputs) {
replace(output->inputs.begin(), output->inputs.end(), opt_op, opt_node);
replace(output->inputs.begin(), output->inputs.end(), opt_op,
fused_opt_node);
}
}
// Remove the dependence vars between op_nodes.
std::unordered_set<ir::Node *> out_dep_vars;
std::unordered_set<ir::Node *> not_useful_vars;
auto deal_with_ctrl_vars = [&out_dep_vars, &not_useful_vars,
&fused_opt_node](ir::Node *ctr_var_node) {
PADDLE_ENFORCE_EQ(ctr_var_node->inputs.size(), 1);
if (ctr_var_node->inputs.front() == fused_opt_node) {
PADDLE_ENFORCE_GT(ctr_var_node->outputs.size(), 0);
auto output_ops = ctr_var_node->outputs;
output_ops.erase(std::remove_if(output_ops.begin(), output_ops.end(),
[&fused_opt_node](const ir::Node *node) {
return node == fused_opt_node;
}),
output_ops.end());
if (!output_ops.empty()) {
out_dep_vars.insert(ctr_var_node);
}
not_useful_vars.insert(ctr_var_node);
}
};
for (auto *in_node : inputs) {
if (in_node->IsCtrlVar()) {
deal_with_ctrl_vars(in_node);
}
}
for (auto *out_node : outputs) {
if (out_node->IsCtrlVar()) {
deal_with_ctrl_vars(out_node);
}
}
for (auto &node : not_useful_vars) {
if (inputs.count(node)) {
inputs.erase(node);
}
if (outputs.count(node)) {
outputs.erase(node);
}
}
for (auto &dep_var : out_dep_vars) {
if (not_useful_vars.count(dep_var)) {
not_useful_vars.erase(dep_var);
}
dep_var->inputs.clear();
dep_var->inputs.emplace_back(fused_opt_node);
}
opt_node->inputs.insert(opt_node->inputs.begin(), inputs.begin(),
outputs.insert(out_dep_vars.begin(), out_dep_vars.end());
fused_opt_node->inputs.insert(fused_opt_node->inputs.begin(), inputs.begin(),
inputs.end());
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
fused_opt_node->outputs.insert(fused_opt_node->outputs.begin(),
outputs.begin(), outputs.end());
for (auto &ctrl_var_node : not_useful_vars) {
graph->RemoveNode(ctrl_var_node);
}
}
} // namespace ir
} // namespace framework
......
......@@ -41,7 +41,8 @@ class FuseOptimizerOpPass : public ir::Pass {
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set,
std::vector<ir::Node *> *ops) const;
void InserInputAndOutputForOptOps(const std::vector<ir::Node *> &opt_ops,
void InsertInputAndOutputForFusedOpNode(
const std::vector<ir::Node *> &opt_ops, ir::Graph *graph,
ir::Node *opt_node) const;
private:
......@@ -49,7 +50,7 @@ class FuseOptimizerOpPass : public ir::Pass {
virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0;
virtual void FuseOptimizerOps(
virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
......@@ -91,6 +92,9 @@ class FuseOptimizerOpPass : public ir::Pass {
*aux_var_set) const;
bool IsLoDTensorType(const proto::VarType::Type &type) const;
bool HasVarDepsBetweenOps(const std::vector<Node *> &topo_nodes,
const std::vector<Node *> &opt_nodes) const;
};
} // namespace ir
......
......@@ -31,7 +31,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
}
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
virtual ir::Node *FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
......@@ -56,9 +56,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
// NOTE: multi_devices_pass requires that every op should have a role.
Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto sgd_node = graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
return graph->CreateOpNode(&Sgd_desc);
}
};
} // namespace ir
......
......@@ -124,7 +124,6 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1
strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
......
......@@ -36,10 +36,8 @@ class TestParallelExecutorBase(unittest.TestCase):
memory_opt=False,
iter=50,
batch_size=None,
allow_op_delay=False,
feed_dict=None,
get_data_from_feeder=None,
seed=None,
use_parallel_executor=True,
use_reduce=False,
use_ir_memory_optimize=True,
......@@ -57,51 +55,23 @@ class TestParallelExecutorBase(unittest.TestCase):
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = 1 # Fix random seed
startup.random_seed = 1
main.random_seed = 1
with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
# NOTE(zjl): memory_optimize/inplace pass would not require
# that loss.persistable = True
loss.persistable = memory_opt
if optimizer:
optimizer().minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
if get_data_from_feeder is not None:
assert feed_dict is None
feed_dict = get_data_from_feeder()
feed_dict, loss = cls.build_model(feed_dict, get_data_from_feeder,
main, memory_opt, method,
optimizer)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
build_strategy, exec_strategy = cls.set_strategy(
enable_inplace, enable_sequential_execution, fuse_all_optimizer_ops,
fuse_all_reduce_ops, fuse_elewise_add_act_ops,
fuse_relu_depthwise_conv, use_fast_executor, use_ir_memory_optimize,
use_reduce, use_cuda)
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
binary = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name,
......@@ -114,13 +84,12 @@ class TestParallelExecutorBase(unittest.TestCase):
batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
for _ in range(iter):
run_executor(exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
......@@ -138,3 +107,85 @@ class TestParallelExecutorBase(unittest.TestCase):
print(first_loss, last_loss)
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
@classmethod
def check_pass_conflict(cls,
method,
use_cuda=True,
memory_opt=False,
feed_dict=None,
get_data_from_feeder=None,
use_reduce=False,
use_ir_memory_optimize=True,
enable_inplace=True,
fuse_elewise_add_act_ops=False,
fuse_all_optimizer_ops=False,
fuse_all_reduce_ops=False,
fuse_relu_depthwise_conv=False,
optimizer=fluid.optimizer.Adam,
use_fast_executor=True,
enable_sequential_execution=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
feed_dict, loss = cls.build_model(feed_dict, get_data_from_feeder,
main, memory_opt, method,
optimizer)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
build_strategy, exec_strategy = cls.set_strategy(
enable_inplace, enable_sequential_execution, fuse_all_optimizer_ops,
fuse_all_reduce_ops, fuse_elewise_add_act_ops,
fuse_relu_depthwise_conv, use_fast_executor, use_ir_memory_optimize,
use_reduce, use_cuda)
binary = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
exe.run(binary, feed=feed_dict, fetch_list=[loss.name])
@classmethod
def set_strategy(cls, enable_inplace, enable_sequential_execution,
fuse_all_optimizer_ops, fuse_all_reduce_ops,
fuse_elewise_add_act_ops, fuse_relu_depthwise_conv,
use_fast_executor, use_ir_memory_optimize, use_reduce,
use_cuda):
exec_strategy = fluid.ExecutionStrategy()
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
return build_strategy, exec_strategy
@classmethod
def build_model(cls, feed_dict, get_data_from_feeder, main, memory_opt,
method, optimizer):
loss = method(use_feed=feed_dict is not None)
# NOTE(zjl): memory_optimize/inplace pass would not require
# that loss.persistable = True
loss.persistable = memory_opt
if optimizer:
optimizer().minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
if get_data_from_feeder is not None:
assert feed_dict is None
feed_dict = get_data_from_feeder()
return feed_dict, loss
......@@ -165,7 +165,6 @@ class TestDistRunnerBase(object):
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
# FIXME force disable enable_inplace and memory_optimize
......
......@@ -74,12 +74,6 @@ class TestFuseAdamOps(TestFuseOptimizationOps):
def optimizer(self, learning_rate=1e-4):
return fluid.optimizer.Adam(learning_rate=learning_rate)
def test_simple_fc_with_fuse_op(self):
self._decorate_compare_fused_optimizer_ops(
simple_fc_net, True, optimizer=self.optimizer)
self._decorate_compare_fused_optimizer_ops(
simple_fc_net, False, optimizer=self.optimizer)
def test_batchnorm_fc_with_fuse_op(self):
self._decorate_compare_fused_optimizer_ops(
fc_with_batchnorm, True, optimizer=self.optimizer)
......@@ -142,5 +136,48 @@ class TestSpareFuseMomentumOps(TestSpareFuseAdamOps):
learning_rate=learning_rate, momentum=0.1)
class TestPassConflictBase(TestFuseAdamOps):
def _compare_fused_optimizer_ops(self,
model,
use_cuda,
feed_dict=None,
get_data_from_feeder=None,
optimizer=fluid.optimizer.Adam):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_pass_conflict(
model,
feed_dict=feed_dict,
get_data_from_feeder=get_data_from_feeder,
use_cuda=use_cuda,
fuse_all_optimizer_ops=True,
memory_opt=False, # avoid the gradient's name changed in Python side.
optimizer=optimizer,
enable_sequential_execution=True)
class TestFuseAdamOpsPassConflict(TestPassConflictBase):
def optimizer(self, learning_rate=1e-4):
return fluid.optimizer.Adam(learning_rate=learning_rate)
def test_batchnorm_fc_with_fuse_op(self):
self._decorate_compare_fused_optimizer_ops(
fc_with_batchnorm, True, optimizer=self.optimizer)
self._decorate_compare_fused_optimizer_ops(
fc_with_batchnorm, False, optimizer=self.optimizer)
class TestFuseSGDOpsPassConflict(TestFuseAdamOpsPassConflict):
def optimizer(self, learning_rate=1e-3):
return fluid.optimizer.SGD(learning_rate=learning_rate)
class TestFuseMomentumOpsPassConflict(TestFuseAdamOpsPassConflict):
def optimizer(self, learning_rate=1e-3):
return fluid.optimizer.Momentum(
learning_rate=learning_rate, momentum=0.1)
if __name__ == '__main__':
unittest.main()
......@@ -135,14 +135,12 @@ class TestMNIST(TestParallelExecutorBase):
single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
......
......@@ -54,14 +54,12 @@ class TestMNIST(TestParallelExecutorBase):
img, label = init_data()
single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册