From 6974265292c6db0080e2bc4d7001cb0641d12d9c Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 21 Sep 2018 20:04:29 +0800 Subject: [PATCH] support offline train --- paddle/fluid/CMakeLists.txt | 2 + paddle/fluid/train/CMakeLists.txt | 28 ++++++ .../train/test_train_recognize_digits.cc | 91 +++++++++++++++++++ python/paddle/fluid/io.py | 61 ++++++------- .../fluid/tests/book/test_recognize_digits.py | 13 +++ 5 files changed, 164 insertions(+), 31 deletions(-) create mode 100644 paddle/fluid/train/CMakeLists.txt create mode 100644 paddle/fluid/train/test_train_recognize_digits.cc diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index ee1f655e25d..519a00fb073 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -13,3 +13,5 @@ if(WITH_INFERENCE) # NOTE: please add subdirectory inference at last. add_subdirectory(inference) endif() + +add_subdirectory(train) diff --git a/paddle/fluid/train/CMakeLists.txt b/paddle/fluid/train/CMakeLists.txt new file mode 100644 index 00000000000..9f10f73637c --- /dev/null +++ b/paddle/fluid/train/CMakeLists.txt @@ -0,0 +1,28 @@ +function(train_test TARGET_NAME) + set(options "") + set(oneValueArgs "") + set(multiValueArgs ARGS) + cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) + set(arg_list "") + if(train_test_ARGS) + foreach(arg ${train_test_ARGS}) + list(APPEND arg_list "_${arg}") + endforeach() + else() + list(APPEND arg_list "_") + endif() + foreach(arg ${arg_list}) + string(REGEX REPLACE "^_$" "" arg "${arg}") + cc_test(test_train_${TARGET_NAME}${arg} + SRCS test_train_${TARGET_NAME}.cc + DEPS paddle_fluid_origin + ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/) + set_tests_properties(test_train_${TARGET_NAME}${arg} + PROPERTIES DEPENDS test_${TARGET_NAME}) + endforeach() +endfunction(train_test) + + +train_test(recognize_digits ARGS mlp conv) diff --git a/paddle/fluid/train/test_train_recognize_digits.cc b/paddle/fluid/train/test_train_recognize_digits.cc new file mode 100644 index 00000000000..45997985ea8 --- /dev/null +++ b/paddle/fluid/train/test_train_recognize_digits.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/init.h" +#include "paddle/fluid/platform/place.h" + +DEFINE_string(dirname, "", "Directory of the train model."); + +namespace paddle { + +void Train() { + CHECK(!FLAGS_dirname.empty()); + framework::InitDevices(false); + const auto cpu_place = platform::CPUPlace(); + framework::Executor executor(cpu_place); + framework::Scope scope; + + auto train_program = inference::Load( + &executor, &scope, FLAGS_dirname + "__model_combined__.main_program", + FLAGS_dirname + "__params_combined__"); + + std::string loss_name = ""; + for (auto op_desc : train_program->Block(0).AllOps()) { + if (op_desc->Type() == "mean") { + loss_name = op_desc->Output("Out")[0]; + break; + } + } + + PADDLE_ENFORCE_NE(loss_name, "", "loss not found"); + + // init all parameters + + // prepare data + auto x_var = scope.Var("img"); + auto x_tensor = x_var->GetMutable(); + x_tensor->Resize({64, 1, 28, 28}); + + auto x_data = x_tensor->mutable_data(cpu_place); + for (int i = 0; i < 64 * 28 * 28; ++i) { + x_data[i] = 1.0; + } + + auto y_var = scope.Var("label"); + auto y_tensor = y_var->GetMutable(); + y_tensor->Resize({64, 1}); + auto y_data = y_tensor->mutable_data(cpu_place); + for (int i = 0; i < 64 * 1; ++i) { + y_data[i] = static_cast(1); + } + + auto loss_var = scope.Var(loss_name); + float first_loss = 0.0; + float last_loss = 0.0; + for (int i = 0; i < 100; ++i) { + executor.Run(*train_program.get(), &scope, 0, false, true); + if (i == 0) { + first_loss = loss_var->Get().data()[0]; + } else if (i == 99) { + last_loss = loss_var->Get().data()[0]; + } + } + EXPECT_LT(last_loss, first_loss); +} + +TEST(train, recognize_digits) { Train(); } + +} // namespace paddle diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e703e5ac794..01d5e7a11ea 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -600,7 +600,7 @@ def save_inference_model(dirname, """ if isinstance(feeded_var_names, six.string_types): feeded_var_names = [feeded_var_names] - else: + elif export_for_deployment: if len(feeded_var_names) > 0: # TODO(paddle-dev): polish these code blocks if not (bool(feeded_var_names) and all( @@ -610,61 +610,60 @@ def save_inference_model(dirname, if isinstance(target_vars, Variable): target_vars = [target_vars] - else: + elif export_for_deployment: if not (bool(target_vars) and all( isinstance(var, Variable) for var in target_vars)): raise ValueError("'target_vars' should be a list of Variable.") if main_program is None: main_program = default_main_program() - copy_program = main_program.clone() + + if params_filename is not None: + params_filename = os.path.basename(params_filename) + save_persistables(executor, dirname, main_program, params_filename) + + # if there is lookup table, the trainer 0 will notify all pserver to save. + if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: + lookup_table_filename = os.path.join(dirname, "__lookup_table__") + _save_lookup_tables_by_notify(executor, lookup_table_filename, + main_program._distributed_lookup_table, + main_program._endpoints) if not os.path.isdir(dirname): os.makedirs(dirname) + if model_filename is not None: + model_basename = os.path.basename(model_filename) + else: + model_basename = "__model__" + model_basename = os.path.join(dirname, model_basename) # When export_for_deployment is true, we modify the program online so that # it can only be loaded for inference directly. If it's false, the whole # original program and related meta are saved so that future usage can be # more flexible. if export_for_deployment: - global_block = copy_program.global_block() + main_program = main_program.clone() + global_block = main_program.global_block() for i, op in enumerate(global_block.ops): op.desc.set_is_target(False) if op.type == "feed" or op.type == "fetch": global_block._remove_op(i) - copy_program.desc.flush() + main_program.desc.flush() - pruned_program = copy_program._prune(targets=target_vars) - saved_program = pruned_program._inference_optimize(prune_read_op=True) + main_program = main_program._prune(targets=target_vars) + main_program = main_program._inference_optimize(prune_read_op=True) fetch_var_names = [v.name for v in target_vars] - prepend_feed_ops(saved_program, feeded_var_names) - append_fetch_ops(saved_program, fetch_var_names) + prepend_feed_ops(main_program, feeded_var_names) + append_fetch_ops(main_program, fetch_var_names) + + with open(model_basename, "wb") as f: + f.write(main_program.desc.serialize_to_string()) else: # TODO(panyx0718): Save more information so that it can also be used # for training and more flexible post-processing. - saved_program = copy_program - - if model_filename is not None: - model_filename = os.path.basename(model_filename) - else: - model_filename = "__model__" - model_filename = os.path.join(dirname, model_filename) - - if params_filename is not None: - params_filename = os.path.basename(params_filename) - - with open(model_filename, "wb") as f: - f.write(saved_program.desc.serialize_to_string()) - - save_persistables(executor, dirname, saved_program, params_filename) - - # if there is lookup table, the trainer 0 will notify all pserver to save. - if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: - lookup_table_filename = os.path.join(dirname, "__lookup_table__") - _save_lookup_tables_by_notify(executor, lookup_table_filename, - main_program._distributed_lookup_table, - main_program._endpoints) + with open(model_basename + ".main_program", "wb") as f: + f.write(main_program.desc.serialize_to_string()) def load_inference_model(dirname, diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index 4b4f3e40377..383f3772253 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -67,6 +67,7 @@ def train(nn_type, use_cuda, parallel, save_dirname=None, + save_full_dirname=None, model_filename=None, params_filename=None, is_local=True): @@ -143,6 +144,15 @@ def train(nn_type, exe, model_filename=model_filename, params_filename=params_filename) + if save_full_dirname is not None: + fluid.io.save_inference_model( + save_full_dirname, + None, + None, + exe, + model_filename=model_filename, + params_filename=params_filename, + export_for_deployment=False) return else: print( @@ -214,10 +224,12 @@ def infer(use_cuda, def main(use_cuda, parallel, nn_type, combine): save_dirname = None + save_full_dirname = None model_filename = None params_filename = None if not use_cuda and not parallel: save_dirname = "recognize_digits_" + nn_type + ".inference.model" + save_full_dirname = "recognize_digits_" + nn_type + ".train.model" if combine == True: model_filename = "__model_combined__" params_filename = "__params_combined__" @@ -228,6 +240,7 @@ def main(use_cuda, parallel, nn_type, combine): use_cuda=use_cuda, parallel=parallel, save_dirname=save_dirname, + save_full_dirname=save_full_dirname, model_filename=model_filename, params_filename=params_filename) infer( -- GitLab