Commit 69742652 authored by Xin Pan

support offline train

Parent eb1aeb17
@@ -13,3 +13,5 @@ if(WITH_INFERENCE)
   # NOTE: please add subdirectory inference at last.
   add_subdirectory(inference)
 endif()
+
+add_subdirectory(train)
# New file: train/CMakeLists.txt
# Registers one C++ training test per variant listed in ARGS. Each generated
# test links against paddle_fluid_origin and loads the
# <target><arg>.train.model directory saved by the corresponding Python book test.
function(train_test TARGET_NAME)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs ARGS)
  cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
  set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
  set(arg_list "")
  if(train_test_ARGS)
    foreach(arg ${train_test_ARGS})
      list(APPEND arg_list "_${arg}")
    endforeach()
  else()
    list(APPEND arg_list "_")
  endif()
  foreach(arg ${arg_list})
    string(REGEX REPLACE "^_$" "" arg "${arg}")
    cc_test(test_train_${TARGET_NAME}${arg}
        SRCS test_train_${TARGET_NAME}.cc
        DEPS paddle_fluid_origin
        ARGS --dirname=${PYTHON_TESTS_DIR}/book/${TARGET_NAME}${arg}.train.model/)
    # The C++ test can only run after the Python book test has written the model.
    set_tests_properties(test_train_${TARGET_NAME}${arg}
        PROPERTIES DEPENDS test_${TARGET_NAME})
  endforeach()
endfunction(train_test)

train_test(recognize_digits ARGS mlp conv)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <fstream>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
DEFINE_string(dirname, "", "Directory of the train model.");
namespace paddle {

void Train() {
  CHECK(!FLAGS_dirname.empty());
  framework::InitDevices(false);
  const auto cpu_place = platform::CPUPlace();
  framework::Executor executor(cpu_place);
  framework::Scope scope;

  auto train_program = inference::Load(
      &executor, &scope, FLAGS_dirname + "__model_combined__.main_program",
      FLAGS_dirname + "__params_combined__");

  std::string loss_name = "";
  for (auto op_desc : train_program->Block(0).AllOps()) {
    if (op_desc->Type() == "mean") {
      loss_name = op_desc->Output("Out")[0];
      break;
    }
  }
  PADDLE_ENFORCE_NE(loss_name, "", "loss not found");

  // init all parameters

  // prepare data
  auto x_var = scope.Var("img");
  auto x_tensor = x_var->GetMutable<framework::LoDTensor>();
  x_tensor->Resize({64, 1, 28, 28});
  auto x_data = x_tensor->mutable_data<float>(cpu_place);
  for (int i = 0; i < 64 * 28 * 28; ++i) {
    x_data[i] = 1.0;
  }

  auto y_var = scope.Var("label");
  auto y_tensor = y_var->GetMutable<framework::LoDTensor>();
  y_tensor->Resize({64, 1});
  auto y_data = y_tensor->mutable_data<int64_t>(cpu_place);
  for (int i = 0; i < 64 * 1; ++i) {
    y_data[i] = static_cast<int64_t>(1);
  }

  auto loss_var = scope.Var(loss_name);
  float first_loss = 0.0;
  float last_loss = 0.0;
  for (int i = 0; i < 100; ++i) {
    executor.Run(*train_program.get(), &scope, 0, false, true);
    if (i == 0) {
      first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
    } else if (i == 99) {
      last_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
    }
  }
  EXPECT_LT(last_loss, first_loss);
}

TEST(train, recognize_digits) { Train(); }

}  // namespace paddle
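For readers who prefer the Python side, here is a minimal sketch of the same round trip that the C++ test performs. It is not part of this commit: it assumes fluid.Program.parse_from_string and fluid.io.load_persistables are available with the signatures used below, and that the recognize_digits_mlp.train.model directory has already been written by the book test further down.

import os

import numpy as np
import paddle.fluid as fluid

# Hypothetical Python counterpart of Train() above: load the non-pruned
# program and its combined parameters, then run it so the optimizer ops
# keep updating the weights.
dirname = "recognize_digits_mlp.train.model/"  # written by the book test below

place = fluid.CPUPlace()
exe = fluid.Executor(place)

with open(os.path.join(dirname, "__model_combined__.main_program"), "rb") as f:
    train_program = fluid.Program.parse_from_string(f.read())
fluid.io.load_persistables(
    exe, dirname, main_program=train_program, filename="__params_combined__")

# Find the mean op's output, mirroring the loss lookup in the C++ test.
loss_name = next(op.output("Out")[0]
                 for op in train_program.global_block().ops
                 if op.type == "mean")

img = np.ones((64, 1, 28, 28), dtype="float32")
label = np.ones((64, 1), dtype="int64")
for _ in range(10):
    loss, = exe.run(train_program,
                    feed={"img": img, "label": label},
                    fetch_list=[loss_name])
print("final loss:", loss)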
@@ -600,7 +600,7 @@ def save_inference_model(dirname,
     """
     if isinstance(feeded_var_names, six.string_types):
         feeded_var_names = [feeded_var_names]
-    else:
+    elif export_for_deployment:
         if len(feeded_var_names) > 0:
             # TODO(paddle-dev): polish these code blocks
             if not (bool(feeded_var_names) and all(
@@ -610,61 +610,60 @@ def save_inference_model(dirname,
     if isinstance(target_vars, Variable):
         target_vars = [target_vars]
-    else:
+    elif export_for_deployment:
         if not (bool(target_vars) and all(
                 isinstance(var, Variable) for var in target_vars)):
             raise ValueError("'target_vars' should be a list of Variable.")
     if main_program is None:
         main_program = default_main_program()
-    copy_program = main_program.clone()
+    if params_filename is not None:
+        params_filename = os.path.basename(params_filename)
+    save_persistables(executor, dirname, main_program, params_filename)
+    # if there is lookup table, the trainer 0 will notify all pserver to save.
+    if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
+        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+        _save_lookup_tables_by_notify(executor, lookup_table_filename,
+                                      main_program._distributed_lookup_table,
+                                      main_program._endpoints)
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
+    if model_filename is not None:
+        model_basename = os.path.basename(model_filename)
+    else:
+        model_basename = "__model__"
+    model_basename = os.path.join(dirname, model_basename)
     # When export_for_deployment is true, we modify the program online so that
     # it can only be loaded for inference directly. If it's false, the whole
     # original program and related meta are saved so that future usage can be
     # more flexible.
     if export_for_deployment:
-        global_block = copy_program.global_block()
+        main_program = main_program.clone()
+        global_block = main_program.global_block()
         for i, op in enumerate(global_block.ops):
             op.desc.set_is_target(False)
             if op.type == "feed" or op.type == "fetch":
                 global_block._remove_op(i)
-        copy_program.desc.flush()
-        pruned_program = copy_program._prune(targets=target_vars)
-        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        main_program.desc.flush()
+        main_program = main_program._prune(targets=target_vars)
+        main_program = main_program._inference_optimize(prune_read_op=True)
         fetch_var_names = [v.name for v in target_vars]
-        prepend_feed_ops(saved_program, feeded_var_names)
-        append_fetch_ops(saved_program, fetch_var_names)
+        prepend_feed_ops(main_program, feeded_var_names)
+        append_fetch_ops(main_program, fetch_var_names)
+        with open(model_basename, "wb") as f:
+            f.write(main_program.desc.serialize_to_string())
     else:
         # TODO(panyx0718): Save more information so that it can also be used
         # for training and more flexible post-processing.
-        saved_program = copy_program
+        with open(model_basename + ".main_program", "wb") as f:
+            f.write(main_program.desc.serialize_to_string())
-    if model_filename is not None:
-        model_filename = os.path.basename(model_filename)
-    else:
-        model_filename = "__model__"
-    model_filename = os.path.join(dirname, model_filename)
-    if params_filename is not None:
-        params_filename = os.path.basename(params_filename)
-    with open(model_filename, "wb") as f:
-        f.write(saved_program.desc.serialize_to_string())
-    save_persistables(executor, dirname, saved_program, params_filename)
-    # if there is lookup table, the trainer 0 will notify all pserver to save.
-    if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
-        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
-        _save_lookup_tables_by_notify(executor, lookup_table_filename,
-                                      main_program._distributed_lookup_table,
-                                      main_program._endpoints)
 def load_inference_model(dirname,
......
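The comment kept in the hunk above summarizes the new behaviour of fluid.io.save_inference_model: with export_for_deployment=True the program is cloned, pruned, and wrapped with feed/fetch ops before being written to the model file, while with export_for_deployment=False the untouched main program (optimizer ops included) is written next to it with a ".main_program" suffix. Below is a minimal usage sketch, assuming a toy network and illustrative directory names; none of these names come from the commit itself.

import paddle.fluid as fluid

# Hypothetical toy model used only to exercise both export modes.
img = fluid.layers.data(name="img", shape=[1, 28, 28], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
prediction = fluid.layers.fc(input=img, size=10, act="softmax")
loss = fluid.layers.mean(fluid.layers.cross_entropy(input=prediction, label=label))
fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Deployment export: the program is pruned to the inference graph and saved
# as <dirname>/__model_combined__, parameters as __params_combined__.
fluid.io.save_inference_model(
    "digits.inference.model", ["img"], [prediction], exe,
    model_filename="__model_combined__", params_filename="__params_combined__")

# Full export: the whole main program (including optimizer ops) is written to
# <dirname>/__model_combined__.main_program, ready for offline training.
fluid.io.save_inference_model(
    "digits.train.model", None, None, exe,
    model_filename="__model_combined__", params_filename="__params_combined__",
    export_for_deployment=False)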
@@ -67,6 +67,7 @@ def train(nn_type,
           use_cuda,
           parallel,
           save_dirname=None,
+          save_full_dirname=None,
           model_filename=None,
           params_filename=None,
           is_local=True):
@@ -143,6 +144,15 @@ def train(nn_type,
                     exe,
                     model_filename=model_filename,
                     params_filename=params_filename)
+                if save_full_dirname is not None:
+                    fluid.io.save_inference_model(
+                        save_full_dirname,
+                        None,
+                        None,
+                        exe,
+                        model_filename=model_filename,
+                        params_filename=params_filename,
+                        export_for_deployment=False)
                 return
             else:
                 print(
@@ -214,10 +224,12 @@ def infer(use_cuda,
 def main(use_cuda, parallel, nn_type, combine):
     save_dirname = None
+    save_full_dirname = None
     model_filename = None
     params_filename = None
     if not use_cuda and not parallel:
         save_dirname = "recognize_digits_" + nn_type + ".inference.model"
+        save_full_dirname = "recognize_digits_" + nn_type + ".train.model"
         if combine == True:
             model_filename = "__model_combined__"
             params_filename = "__params_combined__"
@@ -228,6 +240,7 @@ def main(use_cuda, parallel, nn_type, combine):
         use_cuda=use_cuda,
         parallel=parallel,
         save_dirname=save_dirname,
+        save_full_dirname=save_full_dirname,
         model_filename=model_filename,
         params_filename=params_filename)
     infer(
......
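As a purely illustrative cross-check (the helper below is hypothetical and not part of the commit), the directories written via save_full_dirname should contain exactly the two files that the new C++ test and the train_test() CMake rule above look for, at least in the combine == True configuration:

import os

def check_train_model_dir(nn_type, root="."):
    # Expected layout of <root>/recognize_digits_<nn_type>.train.model/
    # as read by test_train_recognize_digits.cc.
    dirname = os.path.join(root, "recognize_digits_%s.train.model" % nn_type)
    expected = ["__model_combined__.main_program", "__params_combined__"]
    return {name: os.path.exists(os.path.join(dirname, name)) for name in expected}

for nn_type in ("mlp", "conv"):  # matches train_test(recognize_digits ARGS mlp conv)
    print(nn_type, check_train_model_dir(nn_type))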