提交 7871d545 编写于 作者: W weixing02

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dataset

......@@ -140,7 +140,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
if (timeout) {
if (exception_) {
throw * exception_;
auto exp = *exception_;
exception_.reset();
throw exp;
} else {
continue;
}
......
......@@ -74,7 +74,7 @@ ParallelExecutor::ParallelExecutor(
member_->own_local_scope = false;
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(local_scopes[i]);
member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
}
}
......
......@@ -21,26 +21,16 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
ThreadedReader(ReaderBase* reader, bool safe_mode)
: DecoratedReader(reader), safe_mode_(safe_mode) {}
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override {
if (safe_mode_) {
PADDLE_THROW(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true.");
}
VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment.";
reader_->ReInit();
}
void ReInit() override { reader_->ReInit(); }
private:
bool safe_mode_;
std::mutex mutex_;
};
......@@ -58,8 +48,7 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
bool safe_mode = Attr<bool>("safe_mode");
out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
out->Reset(new ThreadedReader(underlying_reader.Get()));
}
};
......@@ -67,10 +56,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<bool>("safe_mode",
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment.")
.SetDefault(true);
AddComment(R"DOC(
CreateThreadedReader Operator
......
......@@ -457,8 +457,8 @@ def __create_shared_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var)
def __create_unshared_decorated_reader__(op_type, reader, attrs):
new_reader_name = unique_name(op_type)
def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
new_reader_name = name if name is not None else unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
......@@ -481,12 +481,12 @@ def batch(reader, batch_size):
'create_batch_reader', reader, {'batch_size': int(batch_size)})
def double_buffer(reader, place=None):
def double_buffer(reader, place=None, name=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_unshared_decorated_reader__('create_double_buffer_reader',
reader, attrs)
return __create_unshared_decorated_reader__(
'create_double_buffer_reader', reader, attrs, name=name)
def multi_pass(reader, pass_num):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle.fluid as fluid
import paddle.v2 as paddle
def load_vocab(filename):
"""
load vocabulary
"""
vocab = {}
with open(filename) as f:
wid = 0
for line in f:
vocab[line.strip()] = wid
wid += 1
return vocab
# load word dict with paddle inner function
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict)
# input text data
data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# like placeholder
feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace())
# train data set
BATCH_SIZE = 128
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=10000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.imdb.test(word_dict), batch_size=BATCH_SIZE)
fluid.recordio_writer.convert_reader_to_recordio_file(
"train.recordio", feeder=feeder, reader_creator=train_reader)
fluid.recordio_writer.convert_reader_to_recordio_file(
"test.recordio", feeder=feeder, reader_creator=test_reader)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy
import sys
TRAIN_FILES = ['train.recordio']
TEST_FILES = ['test.recordio']
DICT_DIM = 89528
# embedding dim
emb_dim = 128
# hidden dim
hid_dim = 128
# hidden dim2
hid_dim2 = 96
# class num
class_dim = 2
def network_cfg(is_train, pass_num=100):
with fluid.unique_name.guard():
train_file_obj = fluid.layers.open_files(
filenames=TRAIN_FILES,
pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES,
pass_num=1,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
if is_train:
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
else:
file_obj = test_file_obj
file_obj = fluid.layers.double_buffer(
file_obj,
name="train_double_buffer" if is_train else 'test_double_buffer')
data, label = fluid.layers.read_file(file_obj)
emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim])
# sequence conv with window size = 3
win_size = 3
conv_3 = fluid.nets.sequence_conv_pool(
input=emb,
num_filters=hid_dim,
filter_size=win_size,
act="tanh",
pool_type="max")
# fc layer after conv
fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
# probability of each class
prediction = fluid.layers.fc(input=[fc_1],
size=class_dim,
act="softmax")
# cross entropy loss
cost = fluid.layers.cross_entropy(input=prediction, label=label)
# mean loss
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label)
if is_train:
# SGD optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)
return {
'loss': avg_cost,
'log': [avg_cost, acc],
'file': train_file_obj if is_train else test_file_obj
}
def main():
train = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(train, startup):
train_args = network_cfg(is_train=True)
test = fluid.Program()
with fluid.program_guard(test, fluid.Program()):
test_args = network_cfg(is_train=False)
# startup
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place=place)
exe.run(startup)
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=train_args['loss'].name, main_program=train)
fetch_var_list = [var.name for var in train_args['log']]
for i in xrange(sys.maxint):
result = map(numpy.array,
train_exe.run(fetch_list=fetch_var_list
if i % 1000 == 0 else []))
if len(result) != 0:
print 'Train: ', result
if i % 1000 == 0:
test_exe = fluid.ParallelExecutor(
use_cuda=True, main_program=test, share_vars_from=train_exe)
loss = []
acc = []
try:
while True:
loss_np, acc_np = map(
numpy.array, test_exe.run(fetch_list=fetch_var_list))
loss.append(loss_np[0])
acc.append(acc_np[0])
except:
test_args['file'].reset()
print 'TEST: ', numpy.mean(loss), numpy.mean(acc)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册