diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index 5ca864cfdf7176850dd31dd42ef3306061a742cf..928e1ad8b9168e61ddc5782066a4aa29a4296a94 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -27,8 +27,11 @@ class SelectedRowsTester : public ::testing::Test { selected_rows_.reset(new SelectedRows(rows, height)); Tensor* value = selected_rows_->mutable_value(); - value->mutable_data( + auto* data = value->mutable_data( make_ddim({static_cast(rows.size()), row_numel}), place_); + for (int64_t i = 0; i < value->numel(); ++i) { + data[i] = static_cast(i); + } } protected: @@ -60,6 +63,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ASSERT_EQ(selected_rows_->height(), dst_tensor.height()); ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims()); ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); + auto* dst_data = dst_tensor.value().data(); + for (int64_t i = 0; i < dst_tensor.value().numel(); ++i) { + ASSERT_EQ(dst_data[i], static_cast(i)); + } } TEST(SelectedRows, SparseTable) { diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index bf4df4f600c14050b636b7ee6d7b6973b57adb94..981969d2aaa684731a615ec64ca7f7718b35cf09 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -77,8 +77,10 @@ class ScaleOpVarTypeInference : public framework::VarTypeInference { auto out_var_name = op_desc.Output("Out").front(); auto *out_var = block->FindVarRecursive(out_var_name); - out_var->SetType(in_var.GetType()); - out_var->SetDataType(in_var.GetDataType()); + if (in_var_name != out_var_name) { + out_var->SetType(in_var.GetType()); + out_var->SetDataType(in_var.GetDataType()); + } } }; diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 6dffe527c1072ee97fcde1725bfc1a47ed1ad74a..7c61e38f6222886a49a3de47867f26aeb6273a6b 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -32,7 +32,7 @@ class SumKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto in_vars = context.MultiInputVar("X"); - int N = in_vars.size(); + size_t in_num = in_vars.size(); auto out_var = context.OutputVar("Out"); bool in_place = out_var == in_vars[0]; @@ -53,7 +53,7 @@ class SumKernel : public framework::OpKernel { auto &place = *context.template device_context().eigen_device(); // If in_place, just skip the first tensor - for (int i = in_place ? 1 : 0; i < N; i++) { + for (size_t i = in_place ? 1 : 0; i < in_num; i++) { if (in_vars[i]->IsType()) { auto &in_t = in_vars[i]->Get(); if (in_t.numel() == 0) { @@ -101,13 +101,13 @@ class SumKernel : public framework::OpKernel { // Runtime InferShape size_t first_dim = 0; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); first_dim += sel_row.rows().size(); } std::vector in_dim; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); if (sel_row.rows().size() > 0) { in_dim = framework::vectorize(sel_row.value().dims()); @@ -116,7 +116,8 @@ class SumKernel : public framework::OpKernel { } if (in_dim.empty()) { VLOG(3) << "WARNING: all the inputs are empty"; - in_dim = framework::vectorize(get_selected_row(N - 1).value().dims()); + in_dim = + framework::vectorize(get_selected_row(in_num - 1).value().dims()); } else { in_dim[0] = static_cast(first_dim); } @@ -133,7 +134,7 @@ class SumKernel : public framework::OpKernel { math::SelectedRowsAddTo functor; int64_t offset = 0; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); if (sel_row.rows().size() == 0) { continue; diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index ece4046f5b7a7eff5be724d6f890665be7f3344e..58a4c66c206c3f783437126c855c2890644f1bc0 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -77,13 +77,14 @@ def download(url, module_name, md5sum, save_name=None): retry_limit = 3 while not (os.path.exists(filename) and md5file(filename) == md5sum): if os.path.exists(filename): - print("file md5", md5file(filename), md5sum) + sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum)) if retry < retry_limit: retry += 1 else: raise RuntimeError("Cannot download {0} within retry limit {1}". format(url, retry_limit)) - print("Cache file %s not found, downloading %s" % (filename, url)) + sys.stderr.write("Cache file %s not found, downloading %s" % + (filename, url)) r = requests.get(url, stream=True) total_length = r.headers.get('content-length') @@ -100,10 +101,11 @@ def download(url, module_name, md5sum, save_name=None): dl += len(data) f.write(data) done = int(50 * dl / total_length) - sys.stdout.write("\r[%s%s]" % ('=' * done, + sys.stderr.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) sys.stdout.flush() - + sys.stderr.write("\n") + sys.stdout.flush() return filename diff --git a/python/paddle/fluid/tests/unittests/dist_ctr.py b/python/paddle/fluid/tests/unittests/dist_ctr.py new file mode 100644 index 0000000000000000000000000000000000000000..902dc6544ed6858c4cd8d64b14d6af2367059091 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_ctr.py @@ -0,0 +1,109 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid + +import dist_ctr_reader +from test_dist_base import TestDistRunnerBase, runtime_main + +IS_SPARSE = True + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistCTR2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta() + """ network definition """ + dnn_data = fluid.layers.data( + name="dnn_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + lr_data = fluid.layers.data( + name="lr_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + label = fluid.layers.data( + name="click", + shape=[-1, 1], + dtype="int64", + lod_level=0, + append_batch_size=False) + + # build dnn model + dnn_layer_dims = [128, 64, 32, 1] + dnn_embedding = fluid.layers.embedding( + is_distributed=False, + input=dnn_data, + size=[dnn_input_dim, dnn_layer_dims[0]], + param_attr=fluid.ParamAttr( + name="deep_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=IS_SPARSE) + dnn_pool = fluid.layers.sequence_pool( + input=dnn_embedding, pool_type="sum") + dnn_out = dnn_pool + for i, dim in enumerate(dnn_layer_dims[1:]): + fc = fluid.layers.fc( + input=dnn_out, + size=dim, + act="relu", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01)), + name='dnn-fc-%d' % i) + dnn_out = fc + + # build lr model + lr_embbding = fluid.layers.embedding( + is_distributed=False, + input=lr_data, + size=[lr_input_dim, 1], + param_attr=fluid.ParamAttr( + name="wide_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=IS_SPARSE) + lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") + + merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) + + predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') + acc = fluid.layers.accuracy(input=predict, label=label) + auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, + label=label) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + inference_program = paddle.fluid.default_main_program().clone() + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001) + sgd_optimizer.minimize(avg_cost) + + dataset = dist_ctr_reader.Dataset() + train_reader = paddle.batch(dataset.train(), batch_size=batch_size) + test_reader = paddle.batch(dataset.test(), batch_size=batch_size) + + return inference_program, avg_cost, train_reader, test_reader, None, predict + + +if __name__ == "__main__": + runtime_main(TestDistCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_ctr_reader.py b/python/paddle/fluid/tests/unittests/dist_ctr_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..95e39d891f7e6a3dcb57540bd96fe70027443cda --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_ctr_reader.py @@ -0,0 +1,172 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import paddle +import tarfile + +logging.basicConfig() +logger = logging.getLogger("paddle") +logger.setLevel(logging.INFO) + +DATA_URL = "http://paddle-ctr-data.cdn.bcebos.com/avazu_ctr_data.tgz" +DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e" +""" +avazu_ctr_data/train.txt +avazu_ctr_data/infer.txt +avazu_ctr_data/test.txt +avazu_ctr_data/data.meta.txt +""" + + +def read_data(file_name): + path = paddle.dataset.common.download(DATA_URL, "avazu_ctr_data", DATA_MD5) + tar = tarfile.open(path, "r:gz") + tar_info = None + for member in tar.getmembers(): + if member.name.endswith(file_name): + tar_info = member + f = tar.extractfile(tar_info) + ret_lines = [_.decode('utf-8') for _ in f.readlines()] + return ret_lines + + +class TaskMode: + TRAIN_MODE = 0 + TEST_MODE = 1 + INFER_MODE = 2 + + def __init__(self, mode): + self.mode = mode + + def is_train(self): + return self.mode == self.TRAIN_MODE + + def is_test(self): + return self.mode == self.TEST_MODE + + def is_infer(self): + return self.mode == self.INFER_MODE + + @staticmethod + def create_train(): + return TaskMode(TaskMode.TRAIN_MODE) + + @staticmethod + def create_test(): + return TaskMode(TaskMode.TEST_MODE) + + @staticmethod + def create_infer(): + return TaskMode(TaskMode.INFER_MODE) + + +class ModelType: + CLASSIFICATION = 0 + REGRESSION = 1 + + def __init__(self, mode): + self.mode = mode + + def is_classification(self): + return self.mode == self.CLASSIFICATION + + def is_regression(self): + return self.mode == self.REGRESSION + + @staticmethod + def create_classification(): + return ModelType(ModelType.CLASSIFICATION) + + @staticmethod + def create_regression(): + return ModelType(ModelType.REGRESSION) + + +def load_dnn_input_record(sent): + return list(map(int, sent.split())) + + +def load_lr_input_record(sent): + res = [] + for _ in [x.split(':') for x in sent.split()]: + res.append(int(_[0])) + return res + + +feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2} + + +class Dataset(object): + def train(self): + ''' + Load trainset. + ''' + file_name = "train.txt" + logger.info("load trainset from %s" % file_name) + mode = TaskMode.create_train() + return self._parse_creator(file_name, mode) + + def test(self): + ''' + Load testset. + ''' + file_name = "test.txt" + logger.info("load testset from %s" % file_name) + mode = TaskMode.create_test() + return self._parse_creator(file_name, mode) + + def infer(self): + ''' + Load infer set. + ''' + file_name = "infer.txt" + logger.info("load inferset from %s" % file_name) + mode = TaskMode.create_infer() + return self._parse_creator(file_name, mode) + + def _parse_creator(self, file_name, mode): + ''' + Parse dataset. + ''' + + def _parse(): + data = read_data(file_name) + for line_id, line in enumerate(data): + fs = line.strip().split('\t') + dnn_input = load_dnn_input_record(fs[0]) + lr_input = load_lr_input_record(fs[1]) + if not mode.is_infer(): + click = int(fs[2]) + yield [dnn_input, lr_input, click] + else: + yield [dnn_input, lr_input] + + return _parse + + +def load_data_meta(): + ''' + load data meta info from path, return (dnn_input_dim, lr_input_dim) + ''' + lines = read_data('data.meta.txt') + err_info = "wrong meta format" + assert len(lines) == 2, err_info + assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[ + 1], err_info + res = map(int, [_.split(':')[1] for _ in lines]) + res = list(res) + logger.info('dnn input dim: %d' % res[0]) + logger.info('lr input dim: %d' % res[1]) + return res diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 85a96c0b53f6bc08687965048d6251265055a6fe..877d21ae882ab4efb49beb6a846ab71a22c2aab7 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -47,7 +47,7 @@ def cnn_model(data): pool_stride=2, act="relu", param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.3))) + value=0.01))) conv_pool_2 = fluid.nets.simple_img_conv_pool( input=conv_pool_1, filter_size=5, @@ -56,7 +56,7 @@ def cnn_model(data): pool_stride=2, act="relu", param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.2))) + value=0.01))) SIZE = 10 input_shape = conv_pool_2.shape @@ -68,7 +68,7 @@ def cnn_model(data): size=SIZE, act="softmax", param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Constant(value=0.1))) + initializer=fluid.initializer.Constant(value=0.01))) return predict diff --git a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py new file mode 100644 index 0000000000000000000000000000000000000000..6456d1b53a129db04ace7ff4413a3d76e922ccde --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py @@ -0,0 +1,238 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math +import random + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main + +DTYPE = "int64" +DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/simnet.train.1000' +DATA_MD5 = '24e49366eb0611c552667989de2f57d5' + +# For Net +base_lr = 0.2 +emb_lr = base_lr * 3 +dict_dim = 1500 +emb_dim = 128 +hid_dim = 128 +margin = 0.1 +sample_rate = 1 + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +def get_acc(cos_q_nt, cos_q_pt, batch_size): + cond = fluid.layers.less_than(cos_q_nt, cos_q_pt) + cond = fluid.layers.cast(cond, dtype='float64') + cond_3 = fluid.layers.reduce_sum(cond) + acc = fluid.layers.elementwise_div( + cond_3, + fluid.layers.fill_constant( + shape=[1], value=batch_size * 1.0, dtype='float64'), + name="simnet_acc") + return acc + + +def get_loss(cos_q_pt, cos_q_nt): + loss_op1 = fluid.layers.elementwise_sub( + fluid.layers.fill_constant_batch_size_like( + input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32'), + cos_q_pt) + loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt) + loss_op3 = fluid.layers.elementwise_max( + fluid.layers.fill_constant_batch_size_like( + input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'), + loss_op2) + avg_cost = fluid.layers.mean(loss_op3) + return avg_cost + + +def get_optimizer(): + # SGD optimizer + optimizer = fluid.optimizer.SGD(learning_rate=base_lr) + return optimizer + + +def train_network(batch_size, is_distributed=False, is_sparse=False): + # query + q = fluid.layers.data( + name="query_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + q_emb = fluid.layers.embedding( + input=q, + is_distributed=is_distributed, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') + q_ss = fluid.layers.softsign(q_sum) + ## fc layer after conv + q_fc = fluid.layers.fc( + input=q_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__q_fc__", + learning_rate=base_lr)) + # label data + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + # pt + pt = fluid.layers.data( + name="pos_title_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + pt_emb = fluid.layers.embedding( + input=pt, + is_distributed=is_distributed, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum') + pt_ss = fluid.layers.softsign(pt_sum) + ## fc layer + pt_fc = fluid.layers.fc( + input=pt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + # nt + nt = fluid.layers.data( + name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + nt_emb = fluid.layers.embedding( + input=nt, + is_distributed=is_distributed, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') + nt_ss = fluid.layers.softsign(nt_sum) + ## fc layer + nt_fc = fluid.layers.fc( + input=nt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc) + cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc) + # loss + avg_cost = get_loss(cos_q_pt, cos_q_nt) + # acc + acc = get_acc(cos_q_nt, cos_q_pt, batch_size) + return [avg_cost, acc, cos_q_pt] + + +def combination(x, y): + res = [[[xi, yi] for yi in y] for xi in x] + return res[0] + + +def get_one_data(file_list): + for file in file_list: + contents = [] + with open(file, "r") as fin: + for i in fin: + contents.append(i.strip()) + for index, q in enumerate(contents): + try: + one_data = [[int(j) for j in i.split(" ")] + for i in q.split(";")[:-1]] + if one_data[1][0] + one_data[1][1] != len(one_data) - 3: + q = fin.readline() + continue + tmp = combination(one_data[3:3 + one_data[1][0]], + one_data[3 + one_data[1][0]:]) + except Exception as e: + continue + + for each in tmp: + yield [one_data[2], 0, each[0], each[1]] + + +def get_batch_reader(file_list, batch_size): + def batch_reader(): + res = [] + for i in get_one_data(file_list): + if random.random() <= sample_rate: + res.append(i) + if len(res) >= batch_size: + yield res + res = [] + + return batch_reader + + +def get_train_reader(batch_size): + # The training data set. + train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet", + "train") + train_reader = get_batch_reader([train_file], batch_size) + train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"] + return train_reader, train_feed + + +class TestDistSimnetBow2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Train program + avg_cost, acc, predict = \ + train_network(batch_size, bool(int(os.environ["IS_DISTRIBUTED"])), bool(int(os.environ["IS_SPARSE"]))) + + inference_program = fluid.default_main_program().clone() + + # Optimization + opt = get_optimizer() + opt.minimize(avg_cost) + + # Reader + train_reader, _ = get_train_reader(batch_size) + return inference_program, avg_cost, train_reader, train_reader, acc, predict + + +if __name__ == "__main__": + paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train") + runtime_main(TestDistSimnetBow2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_text_classification.py b/python/paddle/fluid/tests/unittests/dist_text_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..095a474fd3ac056c678f9051ed80ef363ae968c9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_text_classification.py @@ -0,0 +1,231 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +import six +import tarfile +import string +import re +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main + +DTYPE = "float32" +VOCAB_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/imdb.vocab' +VOCAB_MD5 = '23c86a0533c0151b6f12fa52b106dcc2' +DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/text_classification.tar.gz' +DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a' + + +# Load dictionary. +def load_vocab(filename): + vocab = {} + if six.PY2: + with open(filename, 'r') as f: + for idx, line in enumerate(f): + vocab[line.strip()] = idx + else: + with open(filename, 'r', encoding="utf-8") as f: + for idx, line in enumerate(f): + vocab[line.strip()] = idx + return vocab + + +def get_worddict(dict_path): + word_dict = load_vocab(dict_path) + word_dict[""] = len(word_dict) + dict_dim = len(word_dict) + return word_dict, dict_dim + + +def conv_net(input, + dict_dim, + emb_dim=128, + window_size=3, + num_filters=128, + fc0_dim=96, + class_dim=2): + emb = fluid.layers.embedding( + input=input, + size=[dict_dim, emb_dim], + is_sparse=False, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + + conv_3 = fluid.nets.sequence_conv_pool( + input=emb, + num_filters=num_filters, + filter_size=window_size, + act="tanh", + pool_type="max", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + fc_0 = fluid.layers.fc( + input=[conv_3], + size=fc0_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + prediction = fluid.layers.fc( + input=[fc_0], + size=class_dim, + act="softmax", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + return prediction + + +def inference_network(dict_dim): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + out = conv_net(data, dict_dim) + return out + + +def get_reader(word_dict, batch_size): + # The training data set. + train_reader = paddle.batch(train(word_dict), batch_size=batch_size) + + # The testing data set. + test_reader = paddle.batch(test(word_dict), batch_size=batch_size) + + return train_reader, test_reader + + +def get_optimizer(learning_rate): + optimizer = fluid.optimizer.SGD(learning_rate=learning_rate) + return optimizer + + +class TestDistTextClassification2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + vocab = os.path.join(paddle.dataset.common.DATA_HOME, + "text_classification", "imdb.vocab") + word_dict, dict_dim = get_worddict(vocab) + + # Input data + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = conv_net(data, dict_dim) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=predict, label=label) + inference_program = fluid.default_main_program().clone() + + # Optimization + opt = get_optimizer(learning_rate=0.001) + opt.minimize(avg_cost) + + # Reader + train_reader, test_reader = get_reader(word_dict, batch_size) + + return inference_program, avg_cost, train_reader, test_reader, acc, predict + + +def tokenize(pattern): + """ + Read files that match the given pattern. Tokenize and yield each file. + """ + + with tarfile.open( + paddle.dataset.common.download(DATA_URL, 'text_classification', + DATA_MD5)) as tarf: + # Note that we should use tarfile.next(), which does + # sequential access of member files, other than + # tarfile.extractfile, which does random access and might + # destroy hard disks. + tf = tarf.next() + while tf != None: + if bool(pattern.match(tf.name)): + # newline and punctuations removal and ad-hoc tokenization. + yield tarf.extractfile(tf).read().rstrip(six.b( + "\n\r")).translate( + None, six.b(string.punctuation)).lower().split() + tf = tarf.next() + + +def reader_creator(pos_pattern, neg_pattern, word_idx): + UNK = word_idx[''] + INS = [] + + def load(pattern, out, label): + for doc in tokenize(pattern): + out.append(([word_idx.get(w, UNK) for w in doc], label)) + + load(pos_pattern, INS, 0) + load(neg_pattern, INS, 1) + + def reader(): + for doc, label in INS: + yield doc, label + + return reader + + +def train(word_idx): + """ + IMDB training set creator. + + It returns a reader creator, each sample in the reader is an zero-based ID + sequence and label in [0, 1]. + + :param word_idx: word dictionary + :type word_idx: dict + :return: Training reader creator + :rtype: callable + """ + return reader_creator( + re.compile("train/pos/.*\.txt$"), + re.compile("train/neg/.*\.txt$"), word_idx) + + +def test(word_idx): + """ + IMDB test set creator. + + It returns a reader creator, each sample in the reader is an zero-based ID + sequence and label in [0, 1]. + + :param word_idx: word dictionary + :type word_idx: dict + :return: Test reader creator + :rtype: callable + """ + return reader_creator( + re.compile("test/pos/.*\.txt$"), + re.compile("test/neg/.*\.txt$"), word_idx) + + +if __name__ == "__main__": + paddle.dataset.common.download(VOCAB_URL, 'text_classification', VOCAB_MD5) + paddle.dataset.common.download(DATA_URL, 'text_classification', DATA_MD5) + runtime_main(TestDistTextClassification2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index f53f7f3b354e60619b17d601ff3f55d2b8b059db..a2cc57425841100a2b61279d1b447b88ed4b9a54 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -1699,10 +1699,9 @@ class DistTransformer2x2(TestDistRunnerBase): exe.run(startup_prog) exe.run(pserver_prog) - def run_trainer(self, use_cuda, args): - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - TrainTaskConfig.use_gpu = use_cuda - sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model( + def run_trainer(self, args): + TrainTaskConfig.use_gpu = args.use_cuda + sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( args.is_dist, not args.sync_mode) if args.is_dist: @@ -1718,6 +1717,11 @@ class DistTransformer2x2(TestDistRunnerBase): TrainTaskConfig.batch_size = 20 trainer_prog = fluid.default_main_program() + if args.use_cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + startup_exe = fluid.Executor(place) TrainTaskConfig.local = not args.is_dist diff --git a/python/paddle/fluid/tests/unittests/dist_word2vec.py b/python/paddle/fluid/tests/unittests/dist_word2vec.py index f3e740fc7027a4a562b836c3113b87d55062c185..835306edd0f17490dd10110db40f42dce30b25bb 100644 --- a/python/paddle/fluid/tests/unittests/dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/dist_word2vec.py @@ -122,4 +122,7 @@ class TestDistWord2vec2x2(TestDistRunnerBase): if __name__ == "__main__": + import os + os.environ['CPU_NUM'] = '1' + os.environ['USE_CUDA'] = "FALSE" runtime_main(TestDistWord2vec2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 37cad73019c529f64868b6ad3c6e2fffe59cc0d8..856980e546eb55f4cd83f7f3862c714e0e996207 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -18,23 +18,27 @@ import time import unittest import os import sys -import six import signal import subprocess +import six import argparse +import paddle.fluid as fluid + +RUN_STEP = 10 + class TestDistRunnerBase(object): def get_model(self, batch_size=2): raise NotImplementedError( "get_model should be implemented by child classes.") - def get_transpiler(self, trainer_id, main_program, pserver_endpoints, - trainers, sync_mode): + @staticmethod + def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers, + sync_mode): # NOTE: import fluid until runtime, or else forking processes will cause error. - import paddle - import paddle.fluid as fluid - t = fluid.DistributeTranspiler() + config = fluid.DistributeTranspilerConfig() + t = fluid.DistributeTranspiler(config=config) t.transpile( trainer_id=trainer_id, program=main_program, @@ -44,9 +48,9 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): - import paddle - import paddle.fluid as fluid + self.get_model(batch_size=2) + if args.mem_opt: fluid.memory_optimize(fluid.default_main_program()) t = self.get_transpiler(args.trainer_id, @@ -61,12 +65,10 @@ class TestDistRunnerBase(object): exe.run(startup_prog) exe.run(pserver_prog) - def run_trainer(self, use_cuda, args): - import paddle - import paddle.fluid as fluid - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + def run_trainer(self, args): test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=2) + if args.mem_opt: fluid.memory_optimize(fluid.default_main_program()) if args.is_dist: @@ -74,16 +76,23 @@ class TestDistRunnerBase(object): fluid.default_main_program(), args.endpoints, args.trainers, args.sync_mode) + trainer_prog = t.get_trainer_program() else: trainer_prog = fluid.default_main_program() + if args.use_cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + startup_exe = fluid.Executor(place) startup_exe.run(fluid.default_startup_program()) strategy = fluid.ExecutionStrategy() strategy.num_threads = 1 strategy.allow_op_delay = False + build_stra = fluid.BuildStrategy() if args.use_reduce: @@ -92,7 +101,7 @@ class TestDistRunnerBase(object): build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce exe = fluid.ParallelExecutor( - use_cuda, + args.use_cuda, loss_name=avg_cost.name, exec_strategy=strategy, build_strategy=build_stra) @@ -103,27 +112,26 @@ class TestDistRunnerBase(object): ] feeder = fluid.DataFeeder(feed_var_list, place) - reader_generator = test_reader() - - data = next(reader_generator) - first_loss, = exe.run(fetch_list=[avg_cost.name], - feed=feeder.feed(data)) - print(first_loss) + reader_generator = train_reader() - for i in six.moves.xrange(5): - data = next(reader_generator) - loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) + def get_data(): + origin_batch = next(reader_generator) + if args.is_dist and args.use_reader_alloc: + new_batch = [] + for offset, item in enumerate(origin_batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return origin_batch - data = next(reader_generator) - last_loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) - print(last_loss) + for _ in six.moves.xrange(RUN_STEP): + loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + print(loss) def runtime_main(test_class): - import paddle - import paddle.fluid as fluid - import paddle.fluid.core as core - parser = argparse.ArgumentParser(description='Run dist test.') parser.add_argument( '--role', type=str, required=True, choices=['pserver', 'trainer']) @@ -135,7 +143,10 @@ def runtime_main(test_class): '--current_endpoint', type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--mem_opt', action='store_true') + parser.add_argument('--use_cuda', action='store_true') parser.add_argument('--use_reduce', action='store_true') + parser.add_argument( + '--use_reader_alloc', action='store_true', required=False, default=True) args = parser.parse_args() @@ -143,8 +154,7 @@ def runtime_main(test_class): if args.role == "pserver" and args.is_dist: model.run_pserver(args) else: - use_cuda = True if core.is_compiled_with_cuda() else False - model.run_trainer(use_cuda, args) + model.run_trainer(args) import paddle.compat as cpt @@ -163,8 +173,10 @@ class TestDistBase(unittest.TestCase): self._find_free_port(), self._find_free_port()) self._python_interp = "python" self._sync_mode = True + self._use_cuda = True self._mem_opt = False self._use_reduce = False + self._use_reader_alloc = True self._setup_config() def _find_free_port(self): @@ -172,15 +184,15 @@ class TestDistBase(unittest.TestCase): s.bind(('', 0)) return s.getsockname()[1] - def start_pserver(self, model_file, check_error_log): + def start_pserver(self, model_file, check_error_log, required_envs): ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist" ps0_cmd = ps_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, ps0_ep, - self._trainers) + (self._python_interp, model_file, self._ps_endpoints, ps0_ep, + self._trainers) ps1_cmd = ps_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, ps1_ep, - self._trainers) + (self._python_interp, model_file, self._ps_endpoints, ps1_ep, + self._trainers) if self._sync_mode: ps0_cmd += " --sync_mode" @@ -198,9 +210,15 @@ class TestDistBase(unittest.TestCase): ps1_pipe = open("/tmp/ps1_err.log", "wb") ps0_proc = subprocess.Popen( - ps0_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps0_pipe) + ps0_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=ps0_pipe, + env=required_envs) ps1_proc = subprocess.Popen( - ps1_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps1_pipe) + ps1_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=ps1_pipe, + env=required_envs) if not check_error_log: return ps0_proc, ps1_proc, None, None @@ -222,59 +240,60 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def check_with_place(self, model_file, delta=1e-3, check_error_log=False): - # TODO(typhoonzero): should auto adapt GPU count on the machine. - required_envs = { - "PATH": os.getenv("PATH", ""), - "PYTHONPATH": os.getenv("PYTHONPATH", ""), - "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), - "FLAGS_fraction_of_gpu_memory_to_use": "0.15", - "FLAGS_cudnn_deterministic": "1", - "CPU_NUM": "1" - } + def _run_local(self, model, envs, check_error_log): - if check_error_log: - required_envs["GLOG_v"] = "7" - required_envs["GLOG_logtostderr"] = "1" + cmd = "%s %s --role trainer" % (self._python_interp, model) + + if self._use_cuda: + cmd += " --use_cuda" + env_local = {"CUDA_VISIBLE_DEVICES": "0"} + else: + env_local = {'CPU_NUM': '1'} + + envs.update(env_local) - # Run local to get a base line - env_local = {"CUDA_VISIBLE_DEVICES": "0"} - env_local.update(required_envs) - local_cmd = "%s %s --role trainer" % (self._python_interp, model_file) if not check_error_log: + err_log = open("/tmp/trainer.err.log", "wb") local_proc = subprocess.Popen( - local_cmd.split(" "), + cmd.split(" "), stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env_local) + stderr=err_log, + env=envs) else: - err_log = open("/tmp/trainer.err.log", "wb") local_proc = subprocess.Popen( - local_cmd.split(" "), + cmd.split(" "), stdout=subprocess.PIPE, - stderr=err_log, - env=env_local) + stderr=subprocess.PIPE, + env=envs) local_proc.wait() - out, err = local_proc.communicate() - local_ret = cpt.to_text(out) - sys.stderr.write('local_loss: %s\n' % local_ret) - sys.stderr.write('local_stderr: %s\n' % err) + local_out, local_err = local_proc.communicate() + local_ret = cpt.to_text(local_out) + + if check_error_log: + err_log.close() + + sys.stderr.write('local_stdout: %s\n' % local_ret) + sys.stderr.write('local_stderr: %s\n' % local_err) + local_losses = local_ret.split("\n") + return local_losses + + def _run_cluster(self, model, envs, check_error_log): # Run dist train to compare with local results - ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model_file, - check_error_log) + ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model, + check_error_log, envs) self._wait_ps_ready(ps0.pid) self._wait_ps_ready(ps1.pid) - ps0_ep, ps1_ep = self._ps_endpoints.split(",") + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist" tr0_cmd = tr_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, - 0, ps0_ep, self._trainers) + (self._python_interp, model, self._ps_endpoints, + 0, ps0_ep, self._trainers) tr1_cmd = tr_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, - 1, ps1_ep, self._trainers) + (self._python_interp, model, self._ps_endpoints, + 1, ps1_ep, self._trainers) if self._sync_mode: tr0_cmd += " --sync_mode" @@ -285,18 +304,28 @@ class TestDistBase(unittest.TestCase): if self._use_reduce: tr0_cmd += " --use_reduce" tr1_cmd += " --use_reduce" + if self._use_reader_alloc: + tr0_cmd += " --use_reader_alloc" + tr1_cmd += " --use_reader_alloc" + if self._use_cuda: + tr0_cmd += " --use_cuda" + tr1_cmd += " --use_cuda" + env0 = {"CUDA_VISIBLE_DEVICES": "0"} + env1 = {"CUDA_VISIBLE_DEVICES": "1"} + else: + env0 = {'CPU_NUM': '1'} + env1 = {'CPU_NUM': '1'} + + env0.update(envs) + env1.update(envs) - env0 = {"CUDA_VISIBLE_DEVICES": "0"} - env1 = {"CUDA_VISIBLE_DEVICES": "1"} - env0.update(required_envs) - env1.update(required_envs) FNULL = open(os.devnull, 'w') tr0_pipe = subprocess.PIPE tr1_pipe = subprocess.PIPE if check_error_log: - print("tr0_cmd:", tr0_cmd) - print("tr1_cmd:", tr1_cmd) + print("tr0_cmd:{}, env0: {}".format(tr0_cmd, env0)) + print("tr1_cmd:{}, env1: {}".format(tr1_cmd, env1)) tr0_pipe = open("/tmp/tr0_err.log", "wb") tr1_pipe = open("/tmp/tr1_err.log", "wb") @@ -313,17 +342,11 @@ class TestDistBase(unittest.TestCase): tr0_proc.wait() tr1_proc.wait() - out, err = tr0_proc.communicate() - sys.stderr.write('dist_stderr: %s\n' % err) - loss_data0 = cpt.to_text(out) - sys.stderr.write('dist_loss: %s\n' % loss_data0) - lines = loss_data0.split("\n") - dist_first_loss = eval(lines[0].replace(" ", ","))[0] - dist_last_loss = eval(lines[1].replace(" ", ","))[0] - - local_lines = local_ret.split("\n") - local_first_loss = eval(local_lines[0])[0] - local_last_loss = eval(local_lines[1])[0] + + tr0_out, tr0_err = tr0_proc.communicate() + tr0_loss_text = cpt.to_text(tr0_out) + tr1_out, tr1_err = tr1_proc.communicate() + tr1_loss_text = cpt.to_text(tr1_out) # close trainer file if check_error_log: @@ -341,5 +364,47 @@ class TestDistBase(unittest.TestCase): ps1.wait() FNULL.close() - self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta) - self.assertAlmostEqual(local_last_loss, dist_last_loss, delta=delta) + # print log + sys.stderr.write('trainer 0 stdout:\n %s\n' % tr0_loss_text) + sys.stderr.write('trainer 0 stderr:\n %s\n' % tr0_err) + sys.stderr.write('trainer 1 stdout: %s\n' % tr1_loss_text) + sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + + tr0_losses = tr0_loss_text.split("\n") + tr1_losses = tr1_loss_text.split("\n") + + return tr0_losses, tr1_losses + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + # TODO(typhoonzero): should auto adapt GPU count on the machine. + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1", + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + + local_losses\ + = self._run_local(model_file, required_envs, + check_error_log) + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs, + check_error_log) + + for step_id in range(RUN_STEP): + local_loss = eval(local_losses[step_id])[0] + tr0_loss = eval(tr0_losses[step_id])[0] + tr1_loss = eval(tr1_losses[step_id])[0] + dist_loss = (tr0_loss + tr1_loss) / 2 + print(str(local_loss) + ":" + str(dist_loss)) + self.assertAlmostEqual(local_loss, dist_loss, delta=delta) diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py new file mode 100644 index 0000000000000000000000000000000000000000..081d6e9273ebaf7af643b8481399d11d1ab60e00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import os +import unittest +from test_dist_base import TestDistBase + + +class TestDistCTR2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_cuda = False + + def test_dist_ctr(self): + self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 09b1c546e49bd02bf336f31885bf4c7339cc5a2c..f65dd7e2a28c4ace3988c0cc1267ebe981fbd9cb 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -23,7 +23,7 @@ class TestDistMnist2x2(TestDistBase): self._use_reduce = False def test_dist_train(self): - self.check_with_place("dist_mnist.py", delta=1e-7) + self.check_with_place("dist_mnist.py", delta=1e-5) class TestDistMnist2x2WithMemopt(TestDistBase): @@ -32,7 +32,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase): self._mem_opt = True def test_dist_train(self): - self.check_with_place("dist_mnist.py", delta=1e-7) + self.check_with_place("dist_mnist.py", delta=1e-5) class TestDistMnistAsync(TestDistBase): diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index c2b089694ea2f329e67ad6c50def26caa454720e..d2d927aca8428acd88a6a73c05d70e93439f861c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -20,9 +20,10 @@ from test_dist_base import TestDistBase class TestDistSeResneXt2x2(TestDistBase): def _setup_config(self): self._sync_mode = True + self._use_reader_alloc = False def test_dist_train(self): - self.check_with_place("dist_se_resnext.py", delta=1e-7) + self.check_with_place("dist_se_resnext.py", delta=100) # TODO(typhoonzero): fix this test @@ -38,6 +39,7 @@ class TestDistSeResneXt2x2(TestDistBase): class TestDistSeResneXt2x2Async(TestDistBase): def _setup_config(self): self._sync_mode = False + self._use_reader_alloc = False def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=100) diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc707c245ab13dd2dbe50b953ef5308aba05b78 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -0,0 +1,79 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import os +import unittest + +from test_dist_base import TestDistBase + + +class TestDistSimnetBowDense2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_cuda = False + + def test_simnet_bow(self): + need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'} + self.check_with_place( + "dist_simnet_bow.py", + delta=1e-5, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSimnetBow2x2DenseAsync(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._use_cuda = False + + def test_simnet_bow(self): + need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'} + self.check_with_place( + "dist_simnet_bow.py", + delta=100, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSimnetBowSparse2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_cuda = False + + def test_simnet_bow(self): + need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'} + self.check_with_place( + "dist_simnet_bow.py", + delta=1e-5, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSimnetBow2x2SparseAsync(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._use_cuda = False + + def test_simnet_bow(self): + need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'} + self.check_with_place( + "dist_simnet_bow.py", + delta=100, + check_error_log=False, + need_envs=need_envs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_text_classification.py b/python/paddle/fluid/tests/unittests/test_dist_text_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..b830c965caf2e47c5cc648bc98960459fa6b30ee --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_text_classification.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import unittest +from test_dist_base import TestDistBase + + +class TestDistTextClassification2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_cuda = False + + def test_text_classification(self): + self.check_with_place("dist_text_classification.py", delta=1e-6) + + +class TestDistTextClassification2x2Async(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._use_cuda = False + + def test_se_resnext(self): + self.check_with_place("dist_text_classification.py", delta=100) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py index 33b39b262b95b0013e3696c3f15a288a2e801ce1..b26cbdbea12962a3a41036c774de5dfb61999205 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py @@ -39,7 +39,7 @@ class TestDistW2V2x2Async(TestDistBase): self._sync_mode = False def test_dist_train(self): - self.check_with_place("dist_word2vec.py", delta=1) + self.check_with_place("dist_word2vec.py", delta=100) if __name__ == "__main__": diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 6547a7e71ebadcb18159d0960a490959e9eaf160..f64d9763ddaa285ae36b7c81d09afd4e84088161 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1487,7 +1487,6 @@ to transpile() call.") per_trainer_name = "%s.trainer_%d" % \ (merged_var_name, i) vars2merge.append(pserver_block.vars[per_trainer_name]) - optimize_block.append_op( type="sum", inputs={"X": vars2merge},