提交 2e5d44f1 编写于 作者: C chengduoZH

fix fetch op

上级 99acf1da
...@@ -49,7 +49,7 @@ void FetchOpHandle::RunImpl() { ...@@ -49,7 +49,7 @@ void FetchOpHandle::RunImpl() {
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) { for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input); auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx); if (var->generated_op_) var->generated_op_->Wait(cpu_ctx);
} }
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]); auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
...@@ -61,9 +61,14 @@ void FetchOpHandle::RunImpl() { ...@@ -61,9 +61,14 @@ void FetchOpHandle::RunImpl() {
auto &scope = scopes[i]; auto &scope = scopes[i];
auto *var = auto *var =
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name); scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
if (var == nullptr) {
scope->FindVar(var_name);
}
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_name); var_name);
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) { if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TensorCopySync(t, cpu, &tensors_[i]); TensorCopySync(t, cpu, &tensors_[i]);
......
...@@ -721,3 +721,83 @@ class TestCRFModel(unittest.TestCase): ...@@ -721,3 +721,83 @@ class TestCRFModel(unittest.TestCase):
def test_update_dense_parameter(self): def test_update_dense_parameter(self):
self.check_network_convergence(is_sparse=False) self.check_network_convergence(is_sparse=False)
# test fetch op
import paddle.dataset.flowers as flowers
def lenet(data, class_dim):
conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None)
bn1 = fluid.layers.batch_norm(conv1, act='relu')
pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2)
conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None)
bn2 = fluid.layers.batch_norm(conv2, act='relu')
pool2 = fluid.layers.pool2d(bn2, 2, 'max', 2)
fc1 = fluid.layers.fc(pool2, size=500, act='relu')
fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax')
return fc2
class TestFetchOp(unittest.TestCase):
def parallel_exe(self, train_inputs, seed):
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = seed
with fluid.program_guard(main, startup):
data = fluid.layers.data(
name='image', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = lenet(data, class_dim=102)
loss = fluid.layers.cross_entropy(input=out, label=label)
loss = fluid.layers.mean(loss)
opt = fluid.optimizer.Momentum(
learning_rate=0.1,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opt.minimize(loss)
# TODO(zcd): I found that onece the memory optimizer is open,
# parallel_exe doesn't fetch some variable, such as conv2d_0.b_0@GRAD, conv2d_1.b_0@GRAD.
# fluid.memory_optimize(main)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup)
feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
pe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main)
fetch_list = []
for data in train_inputs:
all_vars = main.global_block().vars
for k, v in all_vars.iteritems():
if v.persistable and 'velocity' not in k:
fetch_list.append(k)
ret = pe.run(fetch_list, feed=feeder.feed(data))
result = {}
for i in range(len(fetch_list)):
result[fetch_list[i]] = np.sum(ret[i])
def test_update_sparse_parameter(self):
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
tst_reader_iter = tst_reader()
seed = 100
iters = 4
train_inputs = []
for i in range(iters):
train_inputs.append(tst_reader_iter.next())
self.parallel_exe(train_inputs, seed)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册