提交 cb40c331 编写于 作者: Y Yu Yang

Update unittest

上级 ee97687f
......@@ -33,7 +33,7 @@ void ComputationOpHandle::RunImpl() {
}
}
op_->Run(*scope_, place_);
op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
}
std::string ComputationOpHandle::Name() const { return op_->Type(); }
......
......@@ -112,6 +112,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
ready_ops.clear();
};
// Create local scopes.
for (auto &scope : local_scopes_) {
auto &local_scope = scope->NewScope();
*scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
}
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
......@@ -156,9 +162,32 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Keep loop until all vars are ready.
}
++computation_count_;
auto sync_computation = [&] {
computation_count_ = 0;
// Wait All computational streams
for (auto p : this->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for (auto &scope : local_scopes_) {
auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
kid = nullptr;
scope->DropKids();
}
};
// Wait FetchOps.
for (auto &fetch_op : fetch_ops) {
fetch_op.WaitAndMergeCPUTensors();
sync_computation();
}
if (computation_count_ == max_async_computation) {
sync_computation();
}
return fetch_data;
......
......@@ -48,6 +48,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform::DeviceContextPool fetch_ctxs_;
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
size_t computation_count_{0};
size_t max_async_computation{100};
};
} // namespace details
......
......@@ -178,7 +178,32 @@ def SE_ResNeXt152():
return loss
class ParallelExecutor(unittest.TestCase):
class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self, method, memory_opt=True, iter=10):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = method()
adam = fluid.optimizer.Adam()
adam.minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
first_loss, = exe.run([loss.name])
first_loss = numpy.array(first_loss)
for i in xrange(iter):
exe.run([])
last_loss, = exe.run([loss.name])
last_loss = numpy.array(last_loss)
print first_loss, last_loss
self.assertGreater(first_loss[0], last_loss[0])
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
# Convert mnist to recordio file
......@@ -195,6 +220,16 @@ class ParallelExecutor(unittest.TestCase):
fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder)
def test_simple_fc(self):
self.check_network_convergence(simple_fc_net)
def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm)
class TestResnet(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(flowers.train(), batch_size=4)
feeder = fluid.DataFeeder(
......@@ -208,34 +243,5 @@ class ParallelExecutor(unittest.TestCase):
fluid.recordio_writer.convert_reader_to_recordio_file(
"./flowers.recordio", reader, feeder)
def test_simple_fc(self):
self.check_network_convergence(simple_fc_net)
def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm)
def check_network_convergence(self, method, memory_opt=True, iter=10):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = method()
adam = fluid.optimizer.Adam()
adam.minimize(loss)
if memory_opt:
fluid.memory_optimize(main)
exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True)
first_loss, = exe.run([loss.name])
first_loss = numpy.array(first_loss)
for i in xrange(iter):
exe.run([])
last_loss, = exe.run([loss.name])
last_loss = numpy.array(last_loss)
print first_loss, last_loss
self.assertGreater(first_loss[0], last_loss[0])
def test_resnet(self):
self.check_network_convergence(SE_ResNeXt152, iter=20)
self.check_network_convergence(SE_ResNeXt152, iter=200)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册