Unverified · Commit d5b5004b authored by T tianshuo78520a and committed by GitHub

Delete legacy C++ training user-interface (#31949)

* delete include framework.pb.h

* fix error

* delete fluid_train
Parent b05f6142
@@ -9,4 +9,3 @@ add_subdirectory(pybind)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
add_subdirectory(train)
function(train_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
set(multiValueArgs ARGS)
cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT APPLE AND NOT WIN32)
cc_test(test_train_${TARGET_NAME}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_inference_shared
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
else()
cc_test(test_train_${TARGET_NAME}
SRCS test_train_${TARGET_NAME}.cc
DEPS paddle_inference_io
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
endif()
if(TEST test_train_${TARGET_NAME})
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES FIXTURES_REQUIRED test_${TARGET_NAME}_infer_model)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_train_${TARGET_NAME}
PROPERTIES TIMEOUT 150)
endif()
endif()
endfunction(train_test)
if(WITH_TESTING)
train_test(recognize_digits)
endif()
cmake_minimum_required(VERSION 3.0)
project(cpp_train_demo CXX C)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if(NOT DEFINED PADDLE_LIB)
message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/paddle/lib/dir")
endif()
option(WITH_MKLDNN "Compile PaddlePaddle with MKLDNN" OFF)
option(WITH_MKL "Compile PaddlePaddle with MKL support, default use openblas." OFF)
include_directories("${PADDLE_LIB}")
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
include_directories("${PADDLE_LIB}/third_party/install/glog/include")
include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include")
include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
include_directories("${PADDLE_LIB}/third_party/threadpool")
include_directories("${PADDLE_LIB}/third_party/dlpack")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
add_executable(demo_trainer demo_trainer.cc)
if(WITH_MKLDNN)
add_definitions(-DPADDLE_WITH_MKLDNN)
include_directories("${PADDLE_LIB}/third_party/install/mkldnn/include")
if(WIN32)
set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.lib)
else(WIN32)
set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/libmkldnn.so.0)
endif(WIN32)
endif(WITH_MKLDNN)
if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
if(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib)
else(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so)
endif(WIN32)
else()
if(APPLE)
set(MATH_LIB cblas)
elseif(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.lib)
else()
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
endif(APPLE)
endif()
if(APPLE)
set(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security")
else(APPLE)
set(ARCHIVE_START "-Wl,--whole-archive")
set(ARCHIVE_END "-Wl,--no-whole-archive")
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
endif(APPLE)
target_link_libraries(demo_trainer
${MACOS_LD_FLAGS}
${ARCHIVE_START}
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_inference.so
${ARCHIVE_END}
${MATH_LIB}
${MKLDNN_LIB}
glog gflags protobuf z xxhash
${EXTERNAL_LIB})
### step 1. build paddle lib
```
# WITH_MKL=ON|OFF
# WITH_MKLDNN=ON|OFF
PADDLE_LIB=/paddle/lib/dir
cmake .. -DPADDLE_INSTALL_DIR=$PADDLE_LIB \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_GPU=OFF \
-DWITH_STYLE_CHECK=OFF \
-DWITH_MKL=OFF \
-DWITH_MKLDNN=OFF
make -j8
make -j8 fluid_lib_dist
```
### step 2. generate program desc
```
# please install paddle before running this script
pip install --upgrade paddlepaddle-*.whl
python demo_network.py
```
This will generate two program desc files (an optional sanity check is sketched below):
- startup_program: used to initialize all parameters
- main_program: the main logic of the network
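As a hedged sketch (not part of the original demo), you can deserialize and inspect the generated files from Python, assuming the same paddle.fluid installation used above and that `fluid.Program.parse_from_string` is available in your Paddle version:
```python
# Minimal sanity check: parse the serialized ProgramDesc and list its ops.
# Assumes fluid.Program.parse_from_string exists in the installed Paddle version.
import paddle.fluid as fluid

with open("main_program", "rb") as f:
    main_prog = fluid.Program.parse_from_string(f.read())
print([op.type for op in main_prog.global_block().ops])
```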
### step 3. build demo_trainer and run it.
```
# Make a build dir in the same directory as this README.md.
# The demo dir itself can be put anywhere.
mkdir build
cd build
# WITH_MKL=ON|OFF
# WITH_MKLDNN=ON|OFF
PADDLE_LIB=/paddle/lib/dir
# PADDLE_LIB is the same as PADDLE_INSTALL_DIR used when building the lib
cmake .. -DPADDLE_LIB=$PADDLE_LIB \
-DWITH_MKLDNN=OFF \
-DWITH_MKL=OFF
make
# copy startup_program and main_program to this dir
cp ../startup_program .
cp ../main_program .
# run demo cpp trainer
./demo_trainer
```
The output will be:
```
step: 0 loss: 1069.02
step: 1 loss: 1069.02
step: 2 loss: 1069.02
....
```
#!/bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -x
cd "$(dirname "$0")"
rm -rf build/
set +x
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.framework as framework
def train_network(with_optimize):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None)

    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    if with_optimize:
        sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.00001)
        sgd_optimizer.minimize(avg_cost)
    else:
        fluid.backward.append_backward(avg_cost)


def save_program_desc(network_func):
    startup_program = framework.Program()
    train_program = framework.Program()

    with framework.program_guard(train_program, startup_program):
        network_func(with_optimize=False)

    # serialize_to_string() returns bytes, so open the files in binary mode
    with open("startup_program", "wb") as f:
        f.write(startup_program.desc.serialize_to_string())
    with open("main_program", "wb") as f:
        f.write(train_program.desc.serialize_to_string())


save_program_desc(train_network)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <time.h>
#include <fstream>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace train {
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
fin.is_open(), true,
platform::errors::Unavailable("Failed to open file %s.", filename));
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* executor, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename;
std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str);
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
new paddle::framework::ProgramDesc(program_desc_str));
return main_program;
}
} // namespace train
} // namespace paddle
int main() {
paddle::framework::InitDevices();
const auto cpu_place = paddle::platform::CPUPlace();
paddle::framework::Executor executor(cpu_place);
paddle::framework::Scope scope;
auto startup_program = paddle::train::Load(&executor, "startup_program");
auto train_program = paddle::train::Load(&executor, "main_program");
std::string loss_name = "";
for (auto op_desc : train_program->Block(0).AllOps()) {
if (op_desc->Type() == "mean") {
loss_name = op_desc->Output("Out")[0];
break;
}
}
PADDLE_ENFORCE_NE(loss_name, "",
platform::errors::NotFound("Loss name is not found."));
// init all parameters
executor.Run(*startup_program, &scope, 0);
// prepare data
auto x_var = scope.Var("x");
auto x_tensor = x_var->GetMutable<paddle::framework::LoDTensor>();
x_tensor->Resize({2, 13});
auto x_data = x_tensor->mutable_data<float>(cpu_place);
for (int i = 0; i < 2 * 13; ++i) {
x_data[i] = static_cast<float>(i);
}
auto y_var = scope.Var("y");
auto y_tensor = y_var->GetMutable<paddle::framework::LoDTensor>();
y_tensor->Resize({2, 1});
auto y_data = y_tensor->mutable_data<float>(cpu_place);
for (int i = 0; i < 2 * 1; ++i) {
y_data[i] = static_cast<float>(i);
}
auto loss_var = scope.Var(loss_name);
paddle::platform::ProfilerState pf_state;
pf_state = paddle::platform::ProfilerState::kCPU;
paddle::platform::EnableProfiler(pf_state);
clock_t t1 = clock();
for (int i = 0; i < 10; ++i) {
executor.Run(*train_program, &scope, 0, false, true);
std::cout << "step: " << i << " loss: "
<< loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0]
<< std::endl;
}
clock_t t2 = clock();
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kTotal,
"run_paddle_op_profiler");
std::cout << "run_time = " << t2 - t1 << std::endl;
return 0;
}
#!/bin/bash
set -x
PADDLE_ROOT=$1
TURN_ON_MKL=$2 # use MKL or Openblas
# download models
function download() {
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/main_program
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/startup_program
}
download
# build demo trainer
paddle_install_dir=${PADDLE_ROOT}/build/paddle_install_dir
mkdir -p build
cd build
rm -rf *
cmake .. -DPADDLE_LIB=$paddle_install_dir \
-DWITH_MKLDNN=$TURN_ON_MKL \
-DWITH_MKL=$TURN_ON_MKL
make
cd ..
# run demo trainer
build/demo_trainer
cmake_minimum_required(VERSION 3.0)
project(cpp_imdb_train_demo CXX C)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
if(NOT DEFINED PADDLE_LIB)
message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/paddle/lib/dir")
endif()
option(WITH_MKLDNN "Compile PaddlePaddle with MKLDNN" OFF)
option(WITH_MKL "Compile PaddlePaddle with MKL support, default use openblas." OFF)
include_directories("${PADDLE_LIB}")
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
include_directories("${PADDLE_LIB}/third_party/install/glog/include")
include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include")
include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
include_directories("${PADDLE_LIB}/third_party/threadpool")
include_directories("${PADDLE_LIB}/third_party/dlpack")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib")
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
add_executable(demo_trainer save_model.cc demo_trainer.cc)
if(WITH_MKLDNN)
include_directories("${PADDLE_LIB}/third_party/install/mkldnn/include")
if(WIN32)
set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.lib)
else(WIN32)
set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/libmkldnn.so.0)
endif(WIN32)
endif(WITH_MKLDNN)
if(WITH_MKL)
include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
if(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib)
else(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so)
endif(WIN32)
else()
if(APPLE)
set(MATH_LIB cblas)
elseif(WIN32)
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.lib)
else()
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
endif(APPLE)
endif()
if(APPLE)
set(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security")
else(APPLE)
set(ARCHIVE_START "-Wl,--whole-archive")
set(ARCHIVE_END "-Wl,--no-whole-archive")
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
endif(APPLE)
target_link_libraries(demo_trainer
${MACOS_LD_FLAGS}
${ARCHIVE_START}
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_inference.so
${ARCHIVE_END}
${MATH_LIB}
${MKLDNN_LIB}
glog gflags protobuf z xxhash
${EXTERNAL_LIB})
# Train with C++ inference API
What the C++ inference API is and how to install it:
see: [PaddlePaddle Fluid provides a C++ API for model deployment](https://paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/deploy/inference/index_cn.html)
After downloading the source code of Paddle, you can build your own inference lib:
```shell
PADDLE_ROOT=./Paddle
cd Paddle
mkdir build
cd build
cmake -DPADDLE_INFERENCE_INSTALL_DIR=$PADDLE_ROOT \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_PYTHON=OFF \
-DWITH_MKL=OFF \
-DWITH_GPU=OFF \
-DON_INFER=ON \
..
make
make inference_lib_dist
```
## IMDB task
see: [IMDB Dataset of 50K Movie Reviews | Kaggle](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews)
## Quick Start
### prepare data
```shell
wget https://fleet.bj.bcebos.com/text_classification_data.tar.gz
tar -zxvf text_classification_data.tar.gz
```
### build
```shell
mkdir build
cd build
rm -rf *
PADDLE_LIB=path/to/Paddle/build/paddle_install_dir
cmake .. -DPADDLE_LIB=$PADDLE_LIB -DWITH_MKLDNN=OFF -DWITH_MKL=OFF
make
```
### generate program description
```shell
python generate_program.py bow
```
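Based on generate_program.py in this directory, the command above writes data.proto plus bow_startup_program and bow_main_program (or the cnn_* variants) into the current directory; train.cfg should point at these files. A quick check, as a sketch (the helper below is illustrative, not part of the demo):
```python
# Quick existence check for the files generate_program.py is expected to write
# when run with the "bow" argument.
import os

for name in ("data.proto", "bow_startup_program", "bow_main_program"):
    print(name, "->", "found" if os.path.exists(name) else "missing")
```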
### run
```shell
# After editing train.cfg
sh run.sh
```
## results
Below are the training logs of the BOW model; the losses go down as expected.
```
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0731 22:39:06.974232 10965 demo_trainer.cc:130] Start training...
I0731 22:39:57.395229 10965 demo_trainer.cc:164] epoch: 0; average loss: 0.405706
I0731 22:40:50.262344 10965 demo_trainer.cc:164] epoch: 1; average loss: 0.110746
I0731 22:41:49.731079 10965 demo_trainer.cc:164] epoch: 2; average loss: 0.0475805
I0731 22:43:31.398355 10965 demo_trainer.cc:164] epoch: 3; average loss: 0.0233249
I0731 22:44:58.744391 10965 demo_trainer.cc:164] epoch: 4; average loss: 0.00701507
I0731 22:46:30.451735 10965 demo_trainer.cc:164] epoch: 5; average loss: 0.00258187
I0731 22:48:14.396687 10965 demo_trainer.cc:164] epoch: 6; average loss: 0.00113157
I0731 22:49:56.242744 10965 demo_trainer.cc:164] epoch: 7; average loss: 0.000698234
I0731 22:51:11.585919 10965 demo_trainer.cc:164] epoch: 8; average loss: 0.000510136
I0731 22:52:50.573947 10965 demo_trainer.cc:164] epoch: 9; average loss: 0.000400932
I0731 22:54:02.686152 10965 demo_trainer.cc:164] epoch: 10; average loss: 0.000329259
I0731 22:54:55.233342 10965 demo_trainer.cc:164] epoch: 11; average loss: 0.000278644
I0731 22:56:15.496256 10965 demo_trainer.cc:164] epoch: 12; average loss: 0.000241055
I0731 22:57:45.015926 10965 demo_trainer.cc:164] epoch: 13; average loss: 0.000212085
I0731 22:59:18.419997 10965 demo_trainer.cc:164] epoch: 14; average loss: 0.000189109
I0731 23:00:15.409077 10965 demo_trainer.cc:164] epoch: 15; average loss: 0.000170465
I0731 23:01:38.795770 10965 demo_trainer.cc:164] epoch: 16; average loss: 0.000155051
I0731 23:02:57.289487 10965 demo_trainer.cc:164] epoch: 17; average loss: 0.000142106
I0731 23:03:48.032507 10965 demo_trainer.cc:164] epoch: 18; average loss: 0.000131089
I0731 23:04:51.195230 10965 demo_trainer.cc:164] epoch: 19; average loss: 0.000121605
I0731 23:06:27.008040 10965 demo_trainer.cc:164] epoch: 20; average loss: 0.00011336
I0731 23:07:56.568284 10965 demo_trainer.cc:164] epoch: 21; average loss: 0.000106129
I0731 23:09:23.948290 10965 demo_trainer.cc:164] epoch: 22; average loss: 9.97393e-05
I0731 23:10:56.062590 10965 demo_trainer.cc:164] epoch: 23; average loss: 9.40532e-05
I0731 23:12:23.014047 10965 demo_trainer.cc:164] epoch: 24; average loss: 8.89622e-05
I0731 23:13:21.439818 10965 demo_trainer.cc:164] epoch: 25; average loss: 8.43784e-05
I0731 23:14:56.171597 10965 demo_trainer.cc:164] epoch: 26; average loss: 8.02322e-05
I0731 23:16:01.513542 10965 demo_trainer.cc:164] epoch: 27; average loss: 7.64629e-05
I0731 23:17:18.709139 10965 demo_trainer.cc:164] epoch: 28; average loss: 7.30239e-05
I0731 23:18:41.421555 10965 demo_trainer.cc:164] epoch: 29; average loss: 6.98716e-05
```
I trained a BOW model and a CNN model on the IMDB dataset using this trainer, and also trained the same models with the traditional Python training method.
Results show that the two methods achieve almost the same dev accuracy:
CNN:
<img src="https://user-images.githubusercontent.com/23031310/62356234-32217300-b543-11e9-89fd-a07614904a08.png" width="300">
BOW:
<img src="https://user-images.githubusercontent.com/23031310/62356253-39488100-b543-11e9-9fa2-a399fc1119d6.png" width="300">
I also recorded the training speed of the C++ trainer and the Python training method; the C++ trainer is faster on the CNN model:
<img src="https://user-images.githubusercontent.com/23031310/62356444-af4ce800-b543-11e9-88c8-f3bde1321ea1.png" width="300">
TODO (mapingshuo): find out why the C++ trainer is faster than the Python method on the CNN model.
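For reference, a hedged sketch of loading a saved epoch back into Python for evaluation (assumptions: the non-combined save format used by demo_trainer.cc, a matching paddle.fluid version, and an illustrative checkpoint path):
```python
# Load parameters written by the C++ trainer (one file per persistable variable)
# back into a Python scope; "bow_model/epoch29.model" is only an example path.
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
with open("bow_main_program", "rb") as f:
    main_prog = fluid.Program.parse_from_string(f.read())
fluid.io.load_persistables(exe, "bow_model/epoch29.model", main_program=main_prog)
```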
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <time.h>
#include <fstream>
#include <numeric>
#include "include/save_model.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/dataset_factory.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "gflags/gflags.h"
DEFINE_string(filelist, "train_filelist.txt", "filelist for fluid dataset");
DEFINE_string(data_proto_desc, "data.proto", "data feed protobuf description");
DEFINE_string(startup_program_file, "startup_program",
"startup program description");
DEFINE_string(main_program_file, "", "main program description");
DEFINE_string(loss_name, "mean_0.tmp_0",
"loss tensor name in the main program");
DEFINE_string(save_dir, "cnn_model", "directory to save trained models");
DEFINE_int32(epoch_num, 30, "number of epochs to run when training");
namespace paddle {
namespace train {
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
fin.is_open(), true,
platform::errors::Unavailable("Failed to open file %s.", filename));
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
std::unique_ptr<paddle::framework::ProgramDesc> LoadProgramDesc(
const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename;
std::string program_desc_str;
ReadBinaryFile(model_filename, &program_desc_str);
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
new paddle::framework::ProgramDesc(program_desc_str));
return main_program;
}
bool IsPersistable(const paddle::framework::VarDesc* var) {
if (var->Persistable() &&
var->GetType() != paddle::framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != paddle::framework::proto::VarType::FETCH_LIST &&
var->GetType() != paddle::framework::proto::VarType::RAW) {
return true;
}
return false;
}
} // namespace train
} // namespace paddle
int main(int argc, char* argv[]) {
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
std::cerr << "filelist: " << FLAGS_filelist << std::endl;
std::cerr << "data_proto_desc: " << FLAGS_data_proto_desc << std::endl;
std::cerr << "startup_program_file: " << FLAGS_startup_program_file
<< std::endl;
std::cerr << "main_program_file: " << FLAGS_main_program_file << std::endl;
std::cerr << "loss_name: " << FLAGS_loss_name << std::endl;
std::cerr << "save_dir: " << FLAGS_save_dir << std::endl;
std::cerr << "epoch_num: " << FLAGS_epoch_num << std::endl;
std::string filelist = std::string(FLAGS_filelist);
std::vector<std::string> file_vec;
std::ifstream fin(filelist);
if (fin) {
std::string filename;
while (fin >> filename) {
file_vec.push_back(filename);
}
}
PADDLE_ENFORCE_GE(
file_vec.size(), 1,
platform::errors::InvalidArgument(
"At least one file to train, but received number of file is %d.",
file_vec.size()));
paddle::framework::InitDevices();
const auto cpu_place = paddle::platform::CPUPlace();
paddle::framework::Executor executor(cpu_place);
paddle::framework::Scope scope;
auto startup_program =
paddle::train::LoadProgramDesc(std::string(FLAGS_startup_program_file));
auto main_program =
paddle::train::LoadProgramDesc(std::string(FLAGS_main_program_file));
executor.Run(*startup_program, &scope, 0);
std::string data_feed_desc_str;
paddle::train::ReadBinaryFile(std::string(FLAGS_data_proto_desc),
&data_feed_desc_str);
VLOG(3) << "load data feed desc done.";
std::unique_ptr<paddle::framework::Dataset> dataset_ptr;
dataset_ptr =
paddle::framework::DatasetFactory::CreateDataset("MultiSlotDataset");
VLOG(3) << "initialize dataset ptr done";
// find all params
std::vector<std::string> param_names;
const paddle::framework::BlockDesc& global_block = main_program->Block(0);
for (auto* var : global_block.AllVars()) {
if (paddle::train::IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name();
param_names.push_back(var->Name());
}
}
int epoch_num = FLAGS_epoch_num;
std::string loss_name = FLAGS_loss_name;
auto loss_var = scope.Var(loss_name);
LOG(INFO) << "Start training...";
for (int epoch = 0; epoch < epoch_num; ++epoch) {
VLOG(3) << "Epoch:" << epoch;
// get reader
dataset_ptr->SetFileList(file_vec);
VLOG(3) << "set file list done";
dataset_ptr->SetThreadNum(1);
VLOG(3) << "set thread num done";
dataset_ptr->SetDataFeedDesc(data_feed_desc_str);
VLOG(3) << "set data feed desc done";
dataset_ptr->CreateReaders();
const std::vector<paddle::framework::DataFeed*> readers =
dataset_ptr->GetReaders();
PADDLE_ENFORCE_EQ(readers.size(), 1,
platform::errors::InvalidArgument(
"Readers num(%d) should be equal to thread num(1).",
readers.size()));
readers[0]->SetPlace(paddle::platform::CPUPlace());
const std::vector<std::string>& input_feed_names =
readers[0]->GetUseSlotAlias();
for (auto name : input_feed_names) {
readers[0]->AddFeedVar(scope.Var(name), name);
}
VLOG(3) << "get reader done";
readers[0]->Start();
VLOG(3) << "start a reader";
VLOG(3) << "readers size: " << readers.size();
int step = 0;
std::vector<float> loss_vec;
while (readers[0]->Next() > 0) {
executor.Run(*main_program, &scope, 0, false, true);
loss_vec.push_back(
loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0]);
}
float average_loss =
std::accumulate(loss_vec.begin(), loss_vec.end(), 0.0) / loss_vec.size();
LOG(INFO) << "epoch: " << epoch << "; average loss: " << average_loss;
dataset_ptr->DestroyReaders();
// save model
std::string save_dir_root = FLAGS_save_dir;
std::string save_dir =
save_dir_root + "/epoch" + std::to_string(epoch) + ".model";
paddle::framework::save_model(main_program, &scope, param_names, save_dir,
false);
}
}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import paddle
import logging
import paddle.fluid as fluid
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def load_vocab(filename):
    vocab = {}
    with open(filename) as f:
        wid = 0
        for line in f:
            vocab[line.strip()] = wid
            wid += 1
    vocab["<unk>"] = len(vocab)
    return vocab


if __name__ == "__main__":
    vocab = load_vocab('imdb.vocab')
    dict_dim = len(vocab)
    model_name = sys.argv[1]

    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_batch_size(128)
    dataset.set_pipe_command("python imdb_reader.py")
    dataset.set_use_var([data, label])
    desc = dataset.proto_desc

    with open("data.proto", "w") as f:
        f.write(dataset.desc())

    from nets import *
    if model_name == 'cnn':
        logger.info("Generate program description of CNN net")
        avg_cost, acc, prediction = cnn_net(data, label, dict_dim)
    elif model_name == 'bow':
        logger.info("Generate program description of BOW net")
        avg_cost, acc, prediction = bow_net(data, label, dict_dim)
    else:
        logger.error("no such model: " + model_name)
        exit(0)
    # optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fluid.optimizer.Adagrad(learning_rate=0.01)
    optimizer.minimize(avg_cost)

    with open(model_name + "_main_program", "wb") as f:
        f.write(fluid.default_main_program().desc.serialize_to_string())
    with open(model_name + "_startup_program", "wb") as f:
        f.write(fluid.default_startup_program().desc.serialize_to_string())
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
import re
import paddle.fluid.incubate.data_generator as dg
class IMDBDataset(dg.MultiSlotDataGenerator):
    def load_resource(self, dictfile):
        self._vocab = {}
        wid = 0
        with open(dictfile) as f:
            for line in f:
                self._vocab[line.strip()] = wid
                wid += 1
        self._unk_id = len(self._vocab)
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
        self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])

    def get_words_and_label(self, line):
        send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
                                                              " ").strip()
        label = [int(line.split('|')[-1])]
        words = [x for x in self._pattern.split(send) if x and x != " "]
        feas = [
            self._vocab[x] if x in self._vocab else self._unk_id for x in words
        ]
        return feas, label

    def infer_reader(self, infer_filelist, batch, buf_size):
        def local_iter():
            for fname in infer_filelist:
                with open(fname, "r") as fin:
                    for line in fin:
                        feas, label = self.get_words_and_label(line)
                        yield feas, label

        import paddle
        batch_iter = paddle.batch(
            paddle.reader.shuffle(
                local_iter, buf_size=buf_size),
            batch_size=batch)
        return batch_iter

    def generate_sample(self, line):
        def memory_iter():
            for i in range(1000):
                yield self.return_value

        def data_iter():
            feas, label = self.get_words_and_label(line)
            yield ("words", feas), ("label", label)

        return data_iter


if __name__ == "__main__":
    imdb = IMDBDataset()
    imdb.load_resource("imdb.vocab")
    imdb.run_from_stdin()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
void save_model(const std::unique_ptr<ProgramDesc>& main_program, Scope* scope,
const std::vector<std::string>& param_names,
const std::string& model_name, bool save_combine);
}  // namespace framework
}  // namespace paddle
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
def bow_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2):
    """
    bow net
    """
    emb = fluid.layers.embedding(
        input=data, size=[dict_dim, emb_dim], is_sparse=True)
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
    fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
    prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
def cnn_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            win_size=3):
    """
    conv net
    """
    emb = fluid.layers.embedding(
        input=data, size=[dict_dim, emb_dim], is_sparse=True)
    conv_3 = fluid.nets.sequence_conv_pool(
        input=emb,
        num_filters=hid_dim,
        filter_size=win_size,
        act="tanh",
        pool_type="max")
    fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
    prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
def lstm_net(data,
             label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             emb_lr=30.0):
    """
    lstm net
    """
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr),
        is_sparse=True)
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
    lstm_h, c = fluid.layers.dynamic_lstm(
        input=fc0, size=hid_dim * 4, is_reverse=False)
    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
    lstm_max_tanh = fluid.layers.tanh(lstm_max)
    fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
def gru_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            emb_lr=400.0):
    """
    gru net
    """
    emb = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
    gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
    gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
    gru_max_tanh = fluid.layers.tanh(gru_max)
    fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
set -exu
build/demo_trainer --flagfile="train.cfg"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "include/save_model.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
using std::unique_ptr;
namespace paddle {
namespace framework {
void save_model(const unique_ptr<ProgramDesc>& main_program, Scope* scope,
const std::vector<std::string>& param_names,
const std::string& model_name, bool save_combine) {
auto place = platform::CPUPlace();
const BlockDesc& global_block = main_program->Block(0);
std::vector<std::string> paralist;
for (auto* var : global_block.AllVars()) {
bool is_model_param = false;
for (auto param_name : param_names) {
if (var->Name() == param_name) {
is_model_param = true;
break;
}
}
if (!is_model_param) continue;
if (!save_combine) {
VLOG(3) << "model var name: %s" << var->Name().c_str();
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_name + "/" + var->Name()});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"save", {{"X", {var->Name()}}}, {}, attrs);
save_op->Run(*scope, place);
} else {
paralist.push_back(var->Name());
}
}
if (save_combine) {
std::sort(paralist.begin(), paralist.end());
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_name});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"save_combine", {{"X", paralist}}, {}, attrs);
save_op->Run(*scope, place);
}
}
} // namespace framework
} // namespace paddle
--filelist=train_filelist.txt
--data_proto_desc=data.proto
--loss_name=mean_0.tmp_0
--startup_program_file=bow_startup_program
--main_program_file=bow_main_program
--save_dir=bow_model
--epoch_num=30
train_data/part-0
train_data/part-1
train_data/part-10
train_data/part-11
train_data/part-2
train_data/part-3
train_data/part-4
train_data/part-5
train_data/part-6
train_data/part-7
train_data/part-8
train_data/part-9
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <fstream>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
DEFINE_string(dirname, "", "Directory of the train model.");
namespace paddle {
void Train(std::string model_dir) {
framework::InitDevices();
const auto cpu_place = platform::CPUPlace();
framework::Executor executor(cpu_place);
framework::Scope scope;
auto train_program = inference::Load(
&executor, &scope, model_dir + "__model_combined__.main_program",
model_dir + "__params_combined__");
std::string loss_name = "";
for (auto op_desc : train_program->Block(0).AllOps()) {
if (op_desc->Type() == "mean") {
loss_name = op_desc->Output("Out")[0];
break;
}
}
PADDLE_ENFORCE_NE(loss_name, "",
platform::errors::NotFound("Loss name is not found."));
// prepare data
auto x_var = scope.Var("img");
auto x_tensor = x_var->GetMutable<framework::LoDTensor>();
x_tensor->Resize({64, 1, 28, 28});
auto x_data = x_tensor->mutable_data<float>(cpu_place);
for (int i = 0; i < 64 * 28 * 28; ++i) {
x_data[i] = 1.0;
}
auto y_var = scope.Var("label");
auto y_tensor = y_var->GetMutable<framework::LoDTensor>();
y_tensor->Resize({64, 1});
auto y_data = y_tensor->mutable_data<int64_t>(cpu_place);
for (int i = 0; i < 64 * 1; ++i) {
y_data[i] = static_cast<int64_t>(1);
}
auto loss_var = scope.Var(loss_name);
float first_loss = 0.0;
float last_loss = 0.0;
for (int i = 0; i < 100; ++i) {
executor.Run(*train_program, &scope, 0, false, true,
{loss_name, "img", "label"});
if (i == 0) {
first_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
} else if (i == 99) {
last_loss = loss_var->Get<framework::LoDTensor>().data<float>()[0];
}
}
EXPECT_LT(last_loss, first_loss);
}
TEST(train, recognize_digits) {
CHECK(!FLAGS_dirname.empty());
Train(FLAGS_dirname + "recognize_digits_mlp.train.model/");
Train(FLAGS_dirname + "recognize_digits_conv.train.model/");
}
} // namespace paddle