From a06205f569564c44e2aded4b82ee0d0f7ae93f9a Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 24 Jan 2018 16:06:13 +0800 Subject: [PATCH] Add demo for parallel.do Unify the recognize_digits --- paddle/operators/parallel_do_op.cc | 35 ++++- .../paddle/v2/fluid/tests/book/CMakeLists.txt | 26 +++- python/paddle/v2/fluid/tests/book/__init__.py | 13 ++ .../tests/book/test_fit_a_line_parallel_do.py | 57 ------- .../fluid/tests/book/test_recognize_digits.py | 140 ++++++++++++++++++ .../tests/book/test_recognize_digits_conv.py | 74 --------- .../tests/book/test_recognize_digits_mlp.py | 96 ------------ 7 files changed, 208 insertions(+), 233 deletions(-) create mode 100644 python/paddle/v2/fluid/tests/book/__init__.py delete mode 100644 python/paddle/v2/fluid/tests/book/test_fit_a_line_parallel_do.py create mode 100644 python/paddle/v2/fluid/tests/book/test_recognize_digits.py delete mode 100644 python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py delete mode 100644 python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc index 09e808902f8..f9cf40bb57f 100644 --- a/paddle/operators/parallel_do_op.cc +++ b/paddle/operators/parallel_do_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/executor.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/threadpool.h" +#include "paddle/operators/detail/safe_ref.h" namespace paddle { namespace operators { @@ -39,8 +40,10 @@ static void SplitTensorAndMoveTensorToScopes( const std::vector &names) { size_t num_sub_scopes = 0; for (auto &argu : names) { - auto *var = scope.FindVar(argu); - const auto &tensor = var->Get(); + const auto &tensor = + detail::Ref(scope.FindVar(argu), + "Cannot find variable %s in the parent scope", argu) + .Get(); auto lod_tensors = tensor.SplitLoDTensor(places); for (auto &lod : lod_tensors) { @@ -60,7 +63,9 @@ static void SplitTensorAndMoveTensorToScopes( } for (size_t i = 0; i < lod_tensors.size(); ++i) { - *(*sub_scopes)[i]->Var(argu)->GetMutable() = lod_tensors[i]; + *detail::Ref(sub_scopes->at(i)->Var(argu), + "Cannot find variable in the sub-scope", argu) + .GetMutable() = lod_tensors[i]; } } } @@ -287,6 +292,17 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { this->InputGrad(input_param, false)); } } + auto *g_block = this->grad_block_[0]; + + // All variable name that needed by gradient operators + std::unordered_set all_inputs_in_grad_blocks; + + for (size_t i = 0; i < g_block->OpSize(); ++i) { + auto *op = g_block->Op(i); + for (auto &var_name : op->InputArgumentNames()) { + all_inputs_in_grad_blocks.insert(var_name); + } + } for (auto &output_param : this->OutputNames()) { if (output_param == kParallelScopes) { @@ -295,8 +311,17 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { this->Output(output_param)); } else { grad->SetInput(output_param, this->Output(output_param)); - grad->SetInput(framework::GradVarName(output_param), - this->OutputGrad(output_param)); + std::vector og_names; + for (auto &og_name : this->OutputGrad(output_param)) { + if (all_inputs_in_grad_blocks.count(og_name) != 0) { + // there is some gradient operator needs the og, make this og as the + // input of parallel.do + // if there is no operator need this og, just do not make this og as + // input. + og_names.push_back(og_name); + } + } + grad->SetInput(framework::GradVarName(output_param), og_names); } } grad->SetAttrMap(this->Attrs()); diff --git a/python/paddle/v2/fluid/tests/book/CMakeLists.txt b/python/paddle/v2/fluid/tests/book/CMakeLists.txt index a35abe3e0c4..dda02c03fd5 100644 --- a/python/paddle/v2/fluid/tests/book/CMakeLists.txt +++ b/python/paddle/v2/fluid/tests/book/CMakeLists.txt @@ -1,9 +1,33 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS test_image_classification_train) +list(REMOVE_ITEM TEST_OPS test_image_classification_train test_recognize_digits) py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet) py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg) +py_test(test_recognize_digits_mlp_cpu + SRCS test_recognize_digits.py + ARGS mlp) +py_test(test_recognize_digits_mlp_cuda + SRCS test_recognize_digits.py + ARGS mlp --use_cuda) +py_test(test_recognize_digits_conv_cpu + SRCS test_recognize_digits.py + ARGS conv) +py_test(test_recognize_digits_conv_cuda + SRCS test_recognize_digits.py + ARGS conv --use_cuda) +py_test(test_recognize_digits_mlp_cpu_parallel + SRCS test_recognize_digits.py + ARGS mlp --parallel) +py_test(test_recognize_digits_mlp_cuda_parallel + SRCS test_recognize_digits.py + ARGS mlp --use_cuda --parallel) +py_test(test_recognize_digits_conv_cpu_parallel + SRCS test_recognize_digits.py + ARGS conv --parallel) +py_test(test_recognize_digits_conv_cuda_parallel + SRCS test_recognize_digits.py + ARGS conv --use_cuda --parallel) # default test foreach(src ${TEST_OPS}) diff --git a/python/paddle/v2/fluid/tests/book/__init__.py b/python/paddle/v2/fluid/tests/book/__init__.py new file mode 100644 index 00000000000..b94a21a7e40 --- /dev/null +++ b/python/paddle/v2/fluid/tests/book/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/v2/fluid/tests/book/test_fit_a_line_parallel_do.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line_parallel_do.py deleted file mode 100644 index 9693b26d549..00000000000 --- a/python/paddle/v2/fluid/tests/book/test_fit_a_line_parallel_do.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.v2 as paddle -import paddle.v2.fluid as fluid - -x = fluid.layers.data(name='x', shape=[13], dtype='float32') -y = fluid.layers.data(name='y', shape=[1], dtype='float32') - -places = fluid.layers.get_places() -pd = fluid.layers.ParallelDo(places=places) -with pd.do(): - x_ = pd.read_input(x) - y_ = pd.read_input(y) - y_predict = fluid.layers.fc(input=x_, size=1, act=None) - cost = fluid.layers.square_error_cost(input=y_predict, label=y_) - pd.write_output(fluid.layers.mean(x=cost)) - -avg_cost = fluid.layers.mean(x=pd()) - -sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) -sgd_optimizer.minimize(avg_cost) - -BATCH_SIZE = 20 - -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=BATCH_SIZE) - -place = fluid.CPUPlace() -feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) -exe = fluid.Executor(place) - -exe.run(fluid.default_startup_program()) - -PASS_NUM = 100 -for pass_id in range(PASS_NUM): - for data in train_reader(): - avg_loss_value, = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost]) - print(avg_loss_value) - if avg_loss_value[0] < 10.0: - exit(0) # if avg cost less than 10.0, we think our code is good. -exit(1) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py new file mode 100644 index 00000000000..4ecdcdc6327 --- /dev/null +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -0,0 +1,140 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function +import argparse +import paddle.v2.fluid as fluid +import paddle.v2 as paddle +import sys + + +def parse_arg(): + parser = argparse.ArgumentParser() + parser.add_argument( + "nn_type", + help="The neural network type, in ['mlp', 'conv']", + type=str, + choices=['mlp', 'conv']) + parser.add_argument( + "--parallel", + help='Run in parallel or not', + default=False, + action="store_true") + parser.add_argument( + "--use_cuda", + help="Run the program by using CUDA", + default=False, + action="store_true") + return parser.parse_args() + + +BATCH_SIZE = 64 + + +def loss_net(hidden, label): + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + return fluid.layers.mean(x=loss), fluid.layers.accuracy( + input=prediction, label=label) + + +def mlp(img, label): + hidden = fluid.layers.fc(input=img, size=200, act='tanh') + hidden = fluid.layers.fc(input=hidden, size=200, act='tanh') + return loss_net(hidden, label) + + +def conv_net(img, label): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu") + return loss_net(conv_pool_2, label) + + +def main(): + args = parse_arg() + print("recognize digits with args: {0}".format(" ".join(sys.argv[1:]))) + + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + if args.nn_type == 'mlp': + net_conf = mlp + else: + net_conf = conv_net + + if args.parallel: + places = fluid.layers.get_places() + pd = fluid.layers.ParallelDo(places) + with pd.do(): + img_ = pd.read_input(img) + label_ = pd.read_input(label) + for o in net_conf(img_, label_): + pd.write_output(o) + + avg_loss, acc = pd() + # get mean loss and acc through every devices. + avg_loss = fluid.layers.mean(x=avg_loss) + acc = fluid.layers.mean(x=acc) + else: + avg_loss, acc = net_conf(img, label) + + optimizer = fluid.optimizer.Adam(learning_rate=0.001) + optimizer.minimize(avg_loss) + + place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + + PASS_NUM = 100 + for pass_id in range(PASS_NUM): + for batch_id, data in enumerate(train_reader()): + need_check = (batch_id + 1) % 10 == 0 + + if need_check: + fetch_list = [avg_loss, acc] + else: + fetch_list = [] + + outs = exe.run(feed=feeder.feed(data), fetch_list=fetch_list) + if need_check: + avg_loss_np, acc_np = outs + if float(acc_np) > 0.9: + exit(0) + else: + print( + 'PassID {0:1}, BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'. + format(pass_id, batch_id + 1, + float(avg_loss_np), float(acc_np))) + + +if __name__ == '__main__': + main() diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py deleted file mode 100644 index 4710d16c24e..00000000000 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function -import numpy as np -import paddle.v2 as paddle -import paddle.v2.fluid as fluid - -images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') -label = fluid.layers.data(name='label', shape=[1], dtype='int64') -conv_pool_1 = fluid.nets.simple_img_conv_pool( - input=images, - filter_size=5, - num_filters=20, - pool_size=2, - pool_stride=2, - act="relu") -conv_pool_2 = fluid.nets.simple_img_conv_pool( - input=conv_pool_1, - filter_size=5, - num_filters=50, - pool_size=2, - pool_stride=2, - act="relu") - -predict = fluid.layers.fc(input=conv_pool_2, size=10, act="softmax") -cost = fluid.layers.cross_entropy(input=predict, label=label) -avg_cost = fluid.layers.mean(x=cost) -optimizer = fluid.optimizer.Adam(learning_rate=0.01) -optimizer.minimize(avg_cost) - -accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - -BATCH_SIZE = 50 -PASS_NUM = 3 -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.mnist.train(), buf_size=500), - batch_size=BATCH_SIZE) - -place = fluid.CPUPlace() -exe = fluid.Executor(place) -feeder = fluid.DataFeeder(feed_list=[images, label], place=place) -exe.run(fluid.default_startup_program()) - -for pass_id in range(PASS_NUM): - accuracy.reset(exe) - for data in train_reader(): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) - print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + - str(pass_acc)) - # print loss, acc - if loss < 10.0 and pass_acc > 0.9: - # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. - exit(0) - - pass_acc = accuracy.eval(exe) - print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) - -exit(1) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py deleted file mode 100644 index 8776a65bf80..00000000000 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function -import numpy as np -import paddle.v2 as paddle -import paddle.v2.fluid as fluid - -BATCH_SIZE = 128 -image = fluid.layers.data(name='x', shape=[784], dtype='float32') - -regularizer = fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE) - -hidden1 = fluid.layers.fc(input=image, - size=128, - act='relu', - param_attr=fluid.ParamAttr( - regularizer=regularizer, - gradient_clip=fluid.clip.ClipByValue(10))) - -hidden2 = fluid.layers.fc(input=hidden1, - size=64, - act='relu', - param_attr=regularizer) - -predict = fluid.layers.fc(input=hidden2, - size=10, - act='softmax', - param_attr=regularizer) - -label = fluid.layers.data(name='y', shape=[1], dtype='int64') - -cost = fluid.layers.cross_entropy(input=predict, label=label) -avg_cost = fluid.layers.mean(x=cost) - -optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) -opts = optimizer.minimize(avg_cost) - -accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - -inference_program = fluid.default_main_program().clone() -with fluid.program_guard(inference_program): - test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states - inference_program = fluid.io.get_inference_program(test_target) - -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.mnist.train(), buf_size=8192), - batch_size=BATCH_SIZE) - -test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) - -place = fluid.CPUPlace() -exe = fluid.Executor(place) -feeder = fluid.DataFeeder(feed_list=[image, label], place=place) -exe.run(fluid.default_startup_program()) - -PASS_NUM = 100 -for pass_id in range(PASS_NUM): - accuracy.reset(exe) - for data in train_reader(): - out, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) - - test_accuracy.reset(exe) - for data in test_reader(): - out, acc = exe.run(inference_program, - feed=feeder.feed(data), - fetch_list=[avg_cost] + test_accuracy.metrics) - - test_pass_acc = test_accuracy.eval(exe) - print("pass_id=" + str(pass_id) + " train_cost=" + str( - out) + " train_acc=" + str(acc) + " train_pass_acc=" + str(pass_acc) - + " test_acc=" + str(test_pass_acc)) - - if test_pass_acc > 0.7: - fluid.io.save_inference_model( - "./recognize_digits_mlp.inference.model/", ["x"], [predict], - exe) - exit(0) - -exit(1) -- GitLab