diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 83d91afa2549a068a01b774606558c19c6503125..d1db924e6b2161d7797dad1c3425188469ad573f 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -47,7 +47,7 @@ if (ANAKIN_FOUND) set(ANAKIN_SHARED_INFERENCE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/api/api_anakin_engine.cc) endif() set(SHARED_INFERENCE_SRCS - io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc + io.cc ${CMAKE_CURRENT_SOURCE_DIR}/../framework/data_feed.cc ${CMAKE_CURRENT_SOURCE_DIR}/../framework/data_set.cc ${CMAKE_CURRENT_SOURCE_DIR}/../framework/data_feed_factory.cc ${CMAKE_CURRENT_SOURCE_DIR}/../framework/dataset_factory.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc ${mkldnn_quantizer_src} ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc diff --git a/paddle/fluid/train/imdb_demo/CMakeLists.txt b/paddle/fluid/train/imdb_demo/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c97343780281f4fa1a9d0a54149499f99dae2226 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/CMakeLists.txt @@ -0,0 +1,78 @@ +cmake_minimum_required(VERSION 3.0) + +project(cpp_imdb_train_demo CXX C) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + +if(NOT DEFINED PADDLE_LIB) + message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/paddle/lib/dir") +endif() + +option(WITH_MKLDNN "Compile PaddlePaddle with MKLDNN" OFF) +option(WITH_MKL "Compile PaddlePaddle with MKL support, default use openblas." OFF) + +include_directories("${PADDLE_LIB}") +include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") +include_directories("${PADDLE_LIB}/third_party/install/glog/include") +include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") +include_directories("${PADDLE_LIB}/third_party/install/snappy/include") +include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") +include_directories("${PADDLE_LIB}/third_party/install/zlib/include") + +include_directories("${PADDLE_LIB}/third_party/boost") +include_directories("${PADDLE_LIB}/third_party/eigen3") + +link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") +link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") +link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") +link_directories("${PADDLE_LIB}/third_party/install/glog/lib") +link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") +link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") + +add_executable(demo_trainer save_model.cc demo_trainer.cc) + +if(WITH_MKLDNN) + include_directories("${PADDLE_LIB}/third_party/install/mkldnn/include") + if(WIN32) + set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.lib) + else(WIN32) + set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/libmkldnn.so.0) + endif(WIN32) +endif(WITH_MKLDNN) + +if(WITH_MKL) + include_directories("${PADDLE_LIB}/third_party/install/mklml/include") + if(WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib) + else(WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so) + endif(WIN32) +else() + if(APPLE) + set(MATH_LIB cblas) + elseif(WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.lib) + else() + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a) + endif(APPLE) +endif() + +if(APPLE) + set(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security") +else(APPLE) + set(ARCHIVE_START "-Wl,--whole-archive") + set(ARCHIVE_END "-Wl,--no-whole-archive") + set(EXTERNAL_LIB "-lrt -ldl -lpthread") +endif(APPLE) + +target_link_libraries(demo_trainer + ${MACOS_LD_FLAGS} + ${ARCHIVE_START} + ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so + ${ARCHIVE_END} + ${MATH_LIB} + ${MKLDNN_LIB} + glog gflags protobuf snappystream snappy z xxhash + ${EXTERNAL_LIB}) diff --git a/paddle/fluid/train/imdb_demo/README.md b/paddle/fluid/train/imdb_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3c75a4744aba54e3dd56e13b5b4a2fd6646ac45c --- /dev/null +++ b/paddle/fluid/train/imdb_demo/README.md @@ -0,0 +1,97 @@ +# Train with C++ inference API + +What is C++ inference API and how to install it: + +see: [PaddlePaddle Fluid 提供了 C++ API 来支持模型的部署上线](https://paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/deploy/inference/index_cn.html) + +## IMDB task + +see: [IMDB Dataset of 50K Movie Reviews | Kaggle](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews) + +## Quick Start + +### prepare data + +```shell + wget https://fleet.bj.bcebos.com/text_classification_data.tar.gz + tar -zxvf text_classification_data.tar.gz +``` +### build + +```shell + mkdir build + cd build + rm -rf * + PADDLE_LIB=path/to/your/fluid_inference_install_dir/ + cmake .. -DPADDLE_LIB=$PADDLE_LIB -DWITH_MKLDNN=OFF -DWITH_MKL=OFF + make +``` + +### generate program description + +``` + python generate_program.py bow +``` + +### run + +```shell + # After editing train.cfg + sh run.sh +``` + +## results + +Below are training logs on BOW model, the losses go down as expected. + +``` +WARNING: Logging before InitGoogleLogging() is written to STDERR +I0731 22:39:06.974232 10965 demo_trainer.cc:130] Start training... +I0731 22:39:57.395229 10965 demo_trainer.cc:164] epoch: 0; average loss: 0.405706 +I0731 22:40:50.262344 10965 demo_trainer.cc:164] epoch: 1; average loss: 0.110746 +I0731 22:41:49.731079 10965 demo_trainer.cc:164] epoch: 2; average loss: 0.0475805 +I0731 22:43:31.398355 10965 demo_trainer.cc:164] epoch: 3; average loss: 0.0233249 +I0731 22:44:58.744391 10965 demo_trainer.cc:164] epoch: 4; average loss: 0.00701507 +I0731 22:46:30.451735 10965 demo_trainer.cc:164] epoch: 5; average loss: 0.00258187 +I0731 22:48:14.396687 10965 demo_trainer.cc:164] epoch: 6; average loss: 0.00113157 +I0731 22:49:56.242744 10965 demo_trainer.cc:164] epoch: 7; average loss: 0.000698234 +I0731 22:51:11.585919 10965 demo_trainer.cc:164] epoch: 8; average loss: 0.000510136 +I0731 22:52:50.573947 10965 demo_trainer.cc:164] epoch: 9; average loss: 0.000400932 +I0731 22:54:02.686152 10965 demo_trainer.cc:164] epoch: 10; average loss: 0.000329259 +I0731 22:54:55.233342 10965 demo_trainer.cc:164] epoch: 11; average loss: 0.000278644 +I0731 22:56:15.496256 10965 demo_trainer.cc:164] epoch: 12; average loss: 0.000241055 +I0731 22:57:45.015926 10965 demo_trainer.cc:164] epoch: 13; average loss: 0.000212085 +I0731 22:59:18.419997 10965 demo_trainer.cc:164] epoch: 14; average loss: 0.000189109 +I0731 23:00:15.409077 10965 demo_trainer.cc:164] epoch: 15; average loss: 0.000170465 +I0731 23:01:38.795770 10965 demo_trainer.cc:164] epoch: 16; average loss: 0.000155051 +I0731 23:02:57.289487 10965 demo_trainer.cc:164] epoch: 17; average loss: 0.000142106 +I0731 23:03:48.032507 10965 demo_trainer.cc:164] epoch: 18; average loss: 0.000131089 +I0731 23:04:51.195230 10965 demo_trainer.cc:164] epoch: 19; average loss: 0.000121605 +I0731 23:06:27.008040 10965 demo_trainer.cc:164] epoch: 20; average loss: 0.00011336 +I0731 23:07:56.568284 10965 demo_trainer.cc:164] epoch: 21; average loss: 0.000106129 +I0731 23:09:23.948290 10965 demo_trainer.cc:164] epoch: 22; average loss: 9.97393e-05 +I0731 23:10:56.062590 10965 demo_trainer.cc:164] epoch: 23; average loss: 9.40532e-05 +I0731 23:12:23.014047 10965 demo_trainer.cc:164] epoch: 24; average loss: 8.89622e-05 +I0731 23:13:21.439818 10965 demo_trainer.cc:164] epoch: 25; average loss: 8.43784e-05 +I0731 23:14:56.171597 10965 demo_trainer.cc:164] epoch: 26; average loss: 8.02322e-05 +I0731 23:16:01.513542 10965 demo_trainer.cc:164] epoch: 27; average loss: 7.64629e-05 +I0731 23:17:18.709139 10965 demo_trainer.cc:164] epoch: 28; average loss: 7.30239e-05 +I0731 23:18:41.421555 10965 demo_trainer.cc:164] epoch: 29; average loss: 6.98716e-05 +``` + +I trained a Bow model and a CNN model on IMDB dataset using the trainer. At the same time, I also trained the same models using traditional Python training methods. +Results show that the two methods achieve almost the same dev accuracy: + +CNN: + + + +BOW: + + + +I also recorded the training speed of the C++ Trainer and the python training methods, C++ trainer is quicker on CNN model: + + + +#TODO (mapingshuo): find the reason why C++ trainer is quicker on CNN model than python method. diff --git a/paddle/fluid/train/imdb_demo/demo_trainer.cc b/paddle/fluid/train/imdb_demo/demo_trainer.cc new file mode 100644 index 0000000000000000000000000000000000000000..e502635b00759238502acb45d288386c8abb0e84 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/demo_trainer.cc @@ -0,0 +1,183 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "include/save_model.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/dataset_factory.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/init.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/profiler.h" + +#include "gflags/gflags.h" + +DEFINE_string(filelist, "train_filelist.txt", "filelist for fluid dataset"); +DEFINE_string(data_proto_desc, "data.proto", "data feed protobuf description"); +DEFINE_string(startup_program_file, "startup_program", + "startup program description"); +DEFINE_string(main_program_file, "", "main program description"); +DEFINE_string(loss_name, "mean_0.tmp_0", + "loss tensor name in the main program"); +DEFINE_string(save_dir, "cnn_model", "directory to save trained models"); +DEFINE_int32(epoch_num, 30, "number of epochs to run when training"); + +namespace paddle { +namespace train { + +void ReadBinaryFile(const std::string& filename, std::string* contents) { + std::ifstream fin(filename, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); + fin.seekg(0, std::ios::end); + contents->clear(); + contents->resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&(contents->at(0)), contents->size()); + fin.close(); +} + +std::unique_ptr LoadProgramDesc( + const std::string& model_filename) { + VLOG(3) << "loading model from " << model_filename; + std::string program_desc_str; + ReadBinaryFile(model_filename, &program_desc_str); + std::unique_ptr main_program( + new paddle::framework::ProgramDesc(program_desc_str)); + return main_program; +} + +bool IsPersistable(const paddle::framework::VarDesc* var) { + if (var->Persistable() && + var->GetType() != paddle::framework::proto::VarType::FEED_MINIBATCH && + var->GetType() != paddle::framework::proto::VarType::FETCH_LIST && + var->GetType() != paddle::framework::proto::VarType::RAW) { + return true; + } + return false; +} + +} // namespace train +} // namespace paddle + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::cerr << "filelist: " << FLAGS_filelist << std::endl; + std::cerr << "data_proto_desc: " << FLAGS_data_proto_desc << std::endl; + std::cerr << "startup_program_file: " << FLAGS_startup_program_file + << std::endl; + std::cerr << "main_program_file: " << FLAGS_main_program_file << std::endl; + std::cerr << "loss_name: " << FLAGS_loss_name << std::endl; + std::cerr << "save_dir: " << FLAGS_save_dir << std::endl; + std::cerr << "epoch_num: " << FLAGS_epoch_num << std::endl; + + std::string filelist = std::string(FLAGS_filelist); + std::vector file_vec; + std::ifstream fin(filelist); + if (fin) { + std::string filename; + while (fin >> filename) { + file_vec.push_back(filename); + } + } + PADDLE_ENFORCE_GE(file_vec.size(), 1, "At least one file to train"); + paddle::framework::InitDevices(false); + const auto cpu_place = paddle::platform::CPUPlace(); + paddle::framework::Executor executor(cpu_place); + paddle::framework::Scope scope; + auto startup_program = + paddle::train::LoadProgramDesc(std::string(FLAGS_startup_program_file)); + auto main_program = + paddle::train::LoadProgramDesc(std::string(FLAGS_main_program_file)); + + executor.Run(*startup_program, &scope, 0); + + std::string data_feed_desc_str; + paddle::train::ReadBinaryFile(std::string(FLAGS_data_proto_desc), + &data_feed_desc_str); + VLOG(3) << "load data feed desc done."; + std::unique_ptr dataset_ptr; + dataset_ptr = + paddle::framework::DatasetFactory::CreateDataset("MultiSlotDataset"); + VLOG(3) << "initialize dataset ptr done"; + + // find all params + std::vector param_names; + const paddle::framework::BlockDesc& global_block = main_program->Block(0); + for (auto* var : global_block.AllVars()) { + if (paddle::train::IsPersistable(var)) { + VLOG(3) << "persistable variable's name: " << var->Name(); + param_names.push_back(var->Name()); + } + } + + int epoch_num = FLAGS_epoch_num; + std::string loss_name = FLAGS_loss_name; + auto loss_var = scope.Var(loss_name); + + LOG(INFO) << "Start training..."; + + for (int epoch = 0; epoch < epoch_num; ++epoch) { + VLOG(3) << "Epoch:" << epoch; + // get reader + dataset_ptr->SetFileList(file_vec); + VLOG(3) << "set file list done"; + dataset_ptr->SetThreadNum(1); + VLOG(3) << "set thread num done"; + dataset_ptr->SetDataFeedDesc(data_feed_desc_str); + VLOG(3) << "set data feed desc done"; + dataset_ptr->CreateReaders(); + const std::vector readers = + dataset_ptr->GetReaders(); + PADDLE_ENFORCE_EQ(readers.size(), 1, + "readers num should be equal to thread num"); + const std::vector& input_feed_names = + readers[0]->GetUseSlotAlias(); + for (auto name : input_feed_names) { + readers[0]->AddFeedVar(scope.Var(name), name); + } + VLOG(3) << "get reader done"; + readers[0]->Start(); + VLOG(3) << "start a reader"; + VLOG(3) << "readers size: " << readers.size(); + + int step = 0; + std::vector loss_vec; + + while (readers[0]->Next() > 0) { + executor.Run(*main_program, &scope, 0, false, true); + loss_vec.push_back( + loss_var->Get().data()[0]); + } + float average_loss = + accumulate(loss_vec.begin(), loss_vec.end(), 0.0) / loss_vec.size(); + + LOG(INFO) << "epoch: " << epoch << "; average loss: " << average_loss; + dataset_ptr->DestroyReaders(); + + // save model + std::string save_dir_root = FLAGS_save_dir; + std::string save_dir = + save_dir_root + "/epoch" + std::to_string(epoch) + ".model"; + paddle::framework::save_model(main_program, &scope, param_names, save_dir, + false); + } +} diff --git a/paddle/fluid/train/imdb_demo/generate_program.py b/paddle/fluid/train/imdb_demo/generate_program.py new file mode 100644 index 0000000000000000000000000000000000000000..a12282d94ddf9ed3e0824c9af709bd1f5b82556f --- /dev/null +++ b/paddle/fluid/train/imdb_demo/generate_program.py @@ -0,0 +1,72 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import paddle +import logging +import paddle.fluid as fluid + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger("fluid") +logger.setLevel(logging.INFO) + + +def load_vocab(filename): + vocab = {} + with open(filename) as f: + wid = 0 + for line in f: + vocab[line.strip()] = wid + wid += 1 + vocab[""] = len(vocab) + return vocab + + +if __name__ == "__main__": + vocab = load_vocab('imdb.vocab') + dict_dim = len(vocab) + model_name = sys.argv[1] + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_batch_size(128) + dataset.set_pipe_command("python imdb_reader.py") + + dataset.set_use_var([data, label]) + desc = dataset.proto_desc + + with open("data.proto", "w") as f: + f.write(dataset.desc()) + + from nets import * + if model_name == 'cnn': + logger.info("Generate program description of CNN net") + avg_cost, acc, prediction = cnn_net(data, label, dict_dim) + elif model_name == 'bow': + logger.info("Generate program description of BOW net") + avg_cost, acc, prediction = bow_net(data, label, dict_dim) + else: + logger.error("no such model: " + model_name) + exit(0) + # optimizer = fluid.optimizer.SGD(learning_rate=0.01) + optimizer = fluid.optimizer.Adagrad(learning_rate=0.01) + optimizer.minimize(avg_cost) + + with open(model_name + "_main_program", "wb") as f: + f.write(fluid.default_main_program().desc.serialize_to_string()) + + with open(model_name + "_startup_program", "wb") as f: + f.write(fluid.default_startup_program().desc.serialize_to_string()) diff --git a/paddle/fluid/train/imdb_demo/imdb_reader.py b/paddle/fluid/train/imdb_demo/imdb_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..f197c95ec32171fb075bb9deeacd6fc6ae3b16e8 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/imdb_reader.py @@ -0,0 +1,75 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import paddle +import re +import paddle.fluid.incubate.data_generator as dg + + +class IMDBDataset(dg.MultiSlotDataGenerator): + def load_resource(self, dictfile): + self._vocab = {} + wid = 0 + with open(dictfile) as f: + for line in f: + self._vocab[line.strip()] = wid + wid += 1 + self._unk_id = len(self._vocab) + self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))') + self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0]) + + def get_words_and_label(self, line): + send = '|'.join(line.split('|')[:-1]).lower().replace("
", + " ").strip() + label = [int(line.split('|')[-1])] + + words = [x for x in self._pattern.split(send) if x and x != " "] + feas = [ + self._vocab[x] if x in self._vocab else self._unk_id for x in words + ] + return feas, label + + def infer_reader(self, infer_filelist, batch, buf_size): + def local_iter(): + for fname in infer_filelist: + with open(fname, "r") as fin: + for line in fin: + feas, label = self.get_words_and_label(line) + yield feas, label + + import paddle + batch_iter = paddle.batch( + paddle.reader.shuffle( + local_iter, buf_size=buf_size), + batch_size=batch) + return batch_iter + + def generate_sample(self, line): + def memory_iter(): + for i in range(1000): + yield self.return_value + + def data_iter(): + feas, label = self.get_words_and_label(line) + yield ("words", feas), ("label", label) + + return data_iter + + +if __name__ == "__main__": + imdb = IMDBDataset() + imdb.load_resource("imdb.vocab") + imdb.run_from_stdin() diff --git a/paddle/fluid/train/imdb_demo/include/save_model.h b/paddle/fluid/train/imdb_demo/include/save_model.h new file mode 100644 index 0000000000000000000000000000000000000000..452052866855d294676a0792e06df7a4b6ecd76f --- /dev/null +++ b/paddle/fluid/train/imdb_demo/include/save_model.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +void save_model(const std::unique_ptr& main_program, Scope* scope, + const std::vector& param_names, + const std::string& model_name, bool save_combine); +} +} diff --git a/paddle/fluid/train/imdb_demo/nets.py b/paddle/fluid/train/imdb_demo/nets.py new file mode 100644 index 0000000000000000000000000000000000000000..a25e67e3b5d56d1e672915cfade1a24ff6546eeb --- /dev/null +++ b/paddle/fluid/train/imdb_demo/nets.py @@ -0,0 +1,140 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import time +import numpy as np + +import paddle +import paddle.fluid as fluid + + +def bow_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + bow net + """ + emb = fluid.layers.embedding( + input=data, size=[dict_dim, emb_dim], is_sparse=True) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def cnn_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + win_size=3): + """ + conv net + """ + emb = fluid.layers.embedding( + input=data, size=[dict_dim, emb_dim], is_sparse=True) + conv_3 = fluid.nets.sequence_conv_pool( + input=emb, + num_filters=hid_dim, + filter_size=win_size, + act="tanh", + pool_type="max") + + fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2) + + prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def lstm_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=30.0): + """ + lstm net + """ + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr(learning_rate=emb_lr), + is_sparse=True) + + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4) + + lstm_h, c = fluid.layers.dynamic_lstm( + input=fc0, size=hid_dim * 4, is_reverse=False) + + lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max') + lstm_max_tanh = fluid.layers.tanh(lstm_max) + + fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh') + + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction + + +def gru_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=400.0): + """ + gru net + """ + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr(learning_rate=emb_lr)) + + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3) + gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False) + gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max') + gru_max_tanh = fluid.layers.tanh(gru_max) + fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh') + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + + return avg_cost, acc, prediction diff --git a/paddle/fluid/train/imdb_demo/run.sh b/paddle/fluid/train/imdb_demo/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..f71b4bac602a9e6d5c7bea03f3c56043b13547d3 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/run.sh @@ -0,0 +1,3 @@ + +set -exu +build/demo_trainer --flagfile="train.cfg" diff --git a/paddle/fluid/train/imdb_demo/save_model.cc b/paddle/fluid/train/imdb_demo/save_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..49da550dbb7f52912406663df6cf11e21e193bd9 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/save_model.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "include/save_model.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/platform/place.h" + +using std::unique_ptr; + +namespace paddle { +namespace framework { +void save_model(const unique_ptr& main_program, Scope* scope, + const std::vector& param_names, + const std::string& model_name, bool save_combine) { + auto place = platform::CPUPlace(); + const BlockDesc& global_block = main_program->Block(0); + std::vector paralist; + for (auto* var : global_block.AllVars()) { + bool is_model_param = false; + for (auto param_name : param_names) { + if (var->Name() == param_name) { + is_model_param = true; + break; + } + } + + if (!is_model_param) continue; + + if (!save_combine) { + VLOG(3) << "model var name: %s" << var->Name().c_str(); + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", model_name + "/" + var->Name()}); + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", {{"X", {var->Name()}}}, {}, attrs); + + save_op->Run(*scope, place); + } else { + paralist.push_back(var->Name()); + } + } + if (save_combine) { + std::sort(paralist.begin(), paralist.end()); + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", model_name}); + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save_combine", {{"X", paralist}}, {}, attrs); + save_op->Run(*scope, place); + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/train/imdb_demo/train.cfg b/paddle/fluid/train/imdb_demo/train.cfg new file mode 100644 index 0000000000000000000000000000000000000000..1821498890be8c17ff749bee5a9a0be3f2138810 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/train.cfg @@ -0,0 +1,7 @@ +--filelist=train_filelist.txt +--data_proto_desc=data.proto +--loss_name=mean_0.tmp_0 +--startup_program_file=bow_startup_program +--main_program_file=bow_main_program +--save_dir=bow_model +--epoch_num=30 diff --git a/paddle/fluid/train/imdb_demo/train_filelist.txt b/paddle/fluid/train/imdb_demo/train_filelist.txt new file mode 100644 index 0000000000000000000000000000000000000000..dcf088af4176196a503097b7d4e16960bbe5ae10 --- /dev/null +++ b/paddle/fluid/train/imdb_demo/train_filelist.txt @@ -0,0 +1,12 @@ +train_data/part-0 +train_data/part-1 +train_data/part-10 +train_data/part-11 +train_data/part-2 +train_data/part-3 +train_data/part-4 +train_data/part-5 +train_data/part-6 +train_data/part-7 +train_data/part-8 +train_data/part-9 diff --git a/python/paddle/fluid/incubate/data_generator/__init__.py b/python/paddle/fluid/incubate/data_generator/__init__.py index c5d298f951d8a5a73073935d1ef52c357ff9011d..77c3fc6bf2d4fb75709ba9667860b14b2334f5a1 100644 --- a/python/paddle/fluid/incubate/data_generator/__init__.py +++ b/python/paddle/fluid/incubate/data_generator/__init__.py @@ -15,7 +15,7 @@ import os import sys -__all__ = ['MultiSlotDataGenerator'] +__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator'] class DataGenerator(object):