diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 5a6dbbb38fa48bcb43110df1806e4de4bfefa075..405dbb0f66ef783a8a55525987706b273790261e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -269,7 +269,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) -paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1)) +paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index 5ca864cfdf7176850dd31dd42ef3306061a742cf..928e1ad8b9168e61ddc5782066a4aa29a4296a94 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -27,8 +27,11 @@ class SelectedRowsTester : public ::testing::Test { selected_rows_.reset(new SelectedRows(rows, height)); Tensor* value = selected_rows_->mutable_value(); - value->mutable_data( + auto* data = value->mutable_data( make_ddim({static_cast(rows.size()), row_numel}), place_); + for (int64_t i = 0; i < value->numel(); ++i) { + data[i] = static_cast(i); + } } protected: @@ -60,6 +63,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ASSERT_EQ(selected_rows_->height(), dst_tensor.height()); ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims()); ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); + auto* dst_data = dst_tensor.value().data(); + for (int64_t i = 0; i < dst_tensor.value().numel(); ++i) { + ASSERT_EQ(dst_data[i], static_cast(i)); + } } TEST(SelectedRows, SparseTable) { diff --git a/paddle/fluid/operators/auc_op.cc b/paddle/fluid/operators/auc_op.cc index dfaa7456f917c1308984b361afed752f96ea6f59..0784920064a879963cd9725cd9acf4cec7b874ce 100644 --- a/paddle/fluid/operators/auc_op.cc +++ b/paddle/fluid/operators/auc_op.cc @@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel { "Out and Label should have same height."); int num_pred_buckets = ctx->Attrs().Get("num_thresholds") + 1; + int slide_steps = ctx->Attrs().Get("slide_steps"); + + PADDLE_ENFORCE_GE(num_pred_buckets, 1, "num_thresholds must larger than 1"); + PADDLE_ENFORCE_GE(slide_steps, 0, "slide_steps must be natural number"); ctx->SetOutputDim("AUC", {1}); - ctx->SetOutputDim("BatchAUC", {1}); - ctx->SetOutputDim("StatPosOut", {num_pred_buckets}); - ctx->SetOutputDim("StatNegOut", {num_pred_buckets}); + + 
slide_steps = slide_steps == 0 ? 1 : slide_steps; + ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets}); + ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets}); } protected: @@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Label", "A 2D int tensor indicating the label of the training data. " "shape: [batch_size, 1]"); + // TODO(typhoonzero): support weight input AddInput("StatPos", "Statistic value when label = 1"); AddInput("StatNeg", "Statistic value when label = 0"); @@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AUC", "A scalar representing the " "current area-under-the-curve."); - AddOutput("BatchAUC", "The AUC for current batch"); + AddOutput("StatPosOut", "Statistic value when label = 1"); AddOutput("StatNegOut", "Statistic value when label = 0"); AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.") .SetDefault("ROC"); - AddAttr("num_thresholds", - "The number of thresholds to use when discretizing the" - " roc curve.") + AddAttr( + "num_thresholds", + "The number of thresholds to use when discretizing the roc curve.") .SetDefault((2 << 12) - 1); - + AddAttr("slide_steps", "Use slide steps to calc batch auc.") + .SetDefault(1); AddComment(R"DOC( Area Under The Curve (AUC) Operator. diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h index fb0517d70635e090f8c5b59ff9d8420fc34c747b..fb370842d1942c3b3eebecb1fe5e8ffb845cb34b 100644 --- a/paddle/fluid/operators/auc_op.h +++ b/paddle/fluid/operators/auc_op.h @@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel { std::string curve = ctx.Attr("curve"); int num_thresholds = ctx.Attr("num_thresholds"); + // buckets contain numbers from 0 to num_thresholds int num_pred_buckets = num_thresholds + 1; + int slide_steps = ctx.Attr("slide_steps"); // Only use output var for now, make sure it's persistable and // not cleaned up for each batch. @@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel { auto *stat_pos = ctx.Output("StatPosOut"); auto *stat_neg = ctx.Output("StatNegOut"); - auto *stat_pos_data = stat_pos->mutable_data(ctx.GetPlace()); - auto *stat_neg_data = stat_neg->mutable_data(ctx.GetPlace()); - calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds, - auc); + auto *origin_stat_pos = stat_pos->mutable_data(ctx.GetPlace()); + auto *origin_stat_neg = stat_neg->mutable_data(ctx.GetPlace()); - auto *batch_auc = ctx.Output("BatchAUC"); - std::vector stat_pos_batch(num_pred_buckets, 0); - std::vector stat_neg_batch(num_pred_buckets, 0); - calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(), - num_thresholds, batch_auc); + std::vector stat_pos_data(num_pred_buckets, 0); + std::vector stat_neg_data(num_pred_buckets, 0); + + auto stat_pos_calc = stat_pos_data.data(); + auto stat_neg_calc = stat_neg_data.data(); + + statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps, + origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc); + + calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc); } private: @@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel { return (X1 > X2 ? 
(X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0; } - inline static void calcAuc(const framework::ExecutionContext &ctx, - const framework::Tensor *label, + inline static void statAuc(const framework::Tensor *label, const framework::Tensor *predict, - int64_t *stat_pos, int64_t *stat_neg, - int num_thresholds, - framework::Tensor *auc_tensor) { + const int num_pred_buckets, + const int num_thresholds, const int slide_steps, + int64_t *origin_stat_pos, int64_t *origin_stat_neg, + int64_t **stat_pos, int64_t **stat_neg) { size_t batch_size = predict->dims()[0]; size_t inference_width = predict->dims()[1]; const T *inference_data = predict->data(); const auto *label_data = label->data(); - auto *auc = auc_tensor->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < batch_size; i++) { uint32_t binIdx = static_cast( inference_data[i * inference_width + 1] * num_thresholds); if (label_data[i]) { - stat_pos[binIdx] += 1.0; + (*stat_pos)[binIdx] += 1.0; } else { - stat_neg[binIdx] += 1.0; + (*stat_neg)[binIdx] += 1.0; } } + int bucket_length = num_pred_buckets * sizeof(int64_t); + + // will stat auc unlimited. + if (slide_steps == 0) { + for (int slide = 0; slide < num_pred_buckets; ++slide) { + origin_stat_pos[slide] += (*stat_pos)[slide]; + origin_stat_neg[slide] += (*stat_neg)[slide]; + } + + *stat_pos = origin_stat_pos; + *stat_neg = origin_stat_neg; + + } else { + for (int slide = 1; slide < slide_steps; ++slide) { + int dst_idx = (slide - 1) * num_pred_buckets; + int src_inx = slide * num_pred_buckets; + std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx, + bucket_length); + std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx, + bucket_length); + } + + std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets, + *stat_pos, bucket_length); + std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets, + *stat_neg, bucket_length); + + std::memset(*stat_pos, 0, bucket_length); + std::memset(*stat_neg, 0, bucket_length); + + for (int slide = 0; slide < num_pred_buckets; ++slide) { + int stat_pos_steps = 0; + int stat_neg_steps = 0; + for (int step = 0; step < slide_steps; ++step) { + stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets]; + stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets]; + } + (*stat_pos)[slide] += stat_pos_steps; + (*stat_neg)[slide] += stat_neg_steps; + } + } + } + + inline static void calcAuc(const framework::ExecutionContext &ctx, + int64_t *stat_pos, int64_t *stat_neg, + int num_thresholds, + framework::Tensor *auc_tensor) { + auto *auc = auc_tensor->mutable_data(ctx.GetPlace()); + *auc = 0.0f; double totPos = 0.0; @@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel { totPos += stat_pos[idx]; totNeg += stat_neg[idx]; *auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev); - --idx; } diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index bf4df4f600c14050b636b7ee6d7b6973b57adb94..981969d2aaa684731a615ec64ca7f7718b35cf09 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -77,8 +77,10 @@ class ScaleOpVarTypeInference : public framework::VarTypeInference { auto out_var_name = op_desc.Output("Out").front(); auto *out_var = block->FindVarRecursive(out_var_name); - out_var->SetType(in_var.GetType()); - out_var->SetDataType(in_var.GetDataType()); + if (in_var_name != out_var_name) { + out_var->SetType(in_var.GetType()); + out_var->SetDataType(in_var.GetDataType()); + } } }; diff --git a/paddle/fluid/operators/sum_op.h 
b/paddle/fluid/operators/sum_op.h index 6dffe527c1072ee97fcde1725bfc1a47ed1ad74a..7c61e38f6222886a49a3de47867f26aeb6273a6b 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -32,7 +32,7 @@ class SumKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto in_vars = context.MultiInputVar("X"); - int N = in_vars.size(); + size_t in_num = in_vars.size(); auto out_var = context.OutputVar("Out"); bool in_place = out_var == in_vars[0]; @@ -53,7 +53,7 @@ class SumKernel : public framework::OpKernel { auto &place = *context.template device_context().eigen_device(); // If in_place, just skip the first tensor - for (int i = in_place ? 1 : 0; i < N; i++) { + for (size_t i = in_place ? 1 : 0; i < in_num; i++) { if (in_vars[i]->IsType()) { auto &in_t = in_vars[i]->Get(); if (in_t.numel() == 0) { @@ -101,13 +101,13 @@ class SumKernel : public framework::OpKernel { // Runtime InferShape size_t first_dim = 0; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); first_dim += sel_row.rows().size(); } std::vector in_dim; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); if (sel_row.rows().size() > 0) { in_dim = framework::vectorize(sel_row.value().dims()); @@ -116,7 +116,8 @@ class SumKernel : public framework::OpKernel { } if (in_dim.empty()) { VLOG(3) << "WARNING: all the inputs are empty"; - in_dim = framework::vectorize(get_selected_row(N - 1).value().dims()); + in_dim = + framework::vectorize(get_selected_row(in_num - 1).value().dims()); } else { in_dim[0] = static_cast(first_dim); } @@ -133,7 +134,7 @@ class SumKernel : public framework::OpKernel { math::SelectedRowsAddTo functor; int64_t offset = 0; - for (int i = 0; i < N; i++) { + for (size_t i = 0; i < in_num; i++) { auto &sel_row = get_selected_row(i); if (sel_row.rows().size() == 0) { continue; diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index ece4046f5b7a7eff5be724d6f890665be7f3344e..58a4c66c206c3f783437126c855c2890644f1bc0 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -77,13 +77,14 @@ def download(url, module_name, md5sum, save_name=None): retry_limit = 3 while not (os.path.exists(filename) and md5file(filename) == md5sum): if os.path.exists(filename): - print("file md5", md5file(filename), md5sum) + sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum)) if retry < retry_limit: retry += 1 else: raise RuntimeError("Cannot download {0} within retry limit {1}". 
format(url, retry_limit)) - print("Cache file %s not found, downloading %s" % (filename, url)) + sys.stderr.write("Cache file %s not found, downloading %s" % + (filename, url)) r = requests.get(url, stream=True) total_length = r.headers.get('content-length') @@ -100,10 +101,11 @@ def download(url, module_name, md5sum, save_name=None): dl += len(data) f.write(data) done = int(50 * dl / total_length) - sys.stdout.write("\r[%s%s]" % ('=' * done, + sys.stderr.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done))) sys.stdout.flush() - + sys.stderr.write("\n") + sys.stdout.flush() return filename diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index b1598bfec210474ae1e17f9f88e8b57aa80b8452..a3064b565d096f7feda18379c66ffc8bf2f4a55c 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None): return acc_out -def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): +def auc(input, + label, + curve='ROC', + num_thresholds=2**12 - 1, + topk=1, + slide_steps=1): """ **Area Under the Curve (AUC) Layer** @@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): num_thresholds(int): The number of thresholds to use when discretizing the roc curve. Default 200. topk(int): only topk number of prediction output will be used for auc. + slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps. + Returns: Variable: A scalar representing the current AUC. @@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): auc_out = helper.create_tmp_variable(dtype="float64") batch_auc_out = helper.create_tmp_variable(dtype="float64") # make tp, tn, fp, fn persistable, so that can accumulate all batches. 
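(Usage sketch, not part of the patch: with the new signature the layer returns the global AUC, the sliding-window batch AUC, and four statistic variables. The tiny network below is illustrative and assumes this version of the fluid layers API.)

import paddle.fluid as fluid

data = fluid.layers.data(name="feature", shape=[32], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
predict = fluid.layers.fc(input=data, size=2, act="softmax")

# slide_steps=3 makes the batch AUC cover the current and the two previous
# batches; the global AUC comes from a second op with slide_steps=0, which
# accumulates over all batches.
auc_out, batch_auc_out, auc_states = fluid.layers.auc(
    input=predict, label=label, num_thresholds=2**12 - 1, slide_steps=3)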
+ + # for batch auc + batch_stat_pos = helper.create_global_variable( + persistable=True, + dtype='int64', + shape=[slide_steps, num_thresholds + 1]) + batch_stat_neg = helper.create_global_variable( + persistable=True, + dtype='int64', + shape=[slide_steps, num_thresholds + 1]) + + # for global auc stat_pos = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) stat_neg = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) - for var in [stat_pos, stat_neg]: + for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]: helper.set_variable_initializer( var, Constant( value=0.0, force_cpu=True)) + # Batch AUC + helper.append_op( + type="auc", + inputs={ + "Predict": [input], + "Label": [label], + "StatPos": [batch_stat_pos], + "StatNeg": [batch_stat_neg] + }, + attrs={ + "curve": curve, + "num_thresholds": num_thresholds, + "slide_steps": slide_steps + }, + outputs={ + "AUC": [batch_auc_out], + "StatPosOut": [batch_stat_pos], + "StatNegOut": [batch_stat_neg] + }) + # Global AUC helper.append_op( type="auc", inputs={ @@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): "StatPos": [stat_pos], "StatNeg": [stat_neg] }, - attrs={"curve": curve, - "num_thresholds": num_thresholds}, + attrs={ + "curve": curve, + "num_thresholds": num_thresholds, + "slide_steps": 0 + }, outputs={ "AUC": [auc_out], - "BatchAUC": [batch_auc_out], "StatPosOut": [stat_pos], "StatNegOut": [stat_neg] }) - return auc_out, batch_auc_out, [stat_pos, stat_neg] + return auc_out, batch_auc_out, [ + batch_stat_pos, batch_stat_neg, stat_pos, stat_neg + ] diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d02c890209e65bdceb5da23ba5b9c7c0356174b8..2ce792728e24712c69e82617b867d1d165e6f097 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -17,6 +17,9 @@ if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec) + LIST(REMOVE_ITEM TEST_OPS test_dist_ctr) + LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow) + LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) endif(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 diff --git a/python/paddle/fluid/tests/unittests/dist_ctr.py b/python/paddle/fluid/tests/unittests/dist_ctr.py new file mode 100644 index 0000000000000000000000000000000000000000..902dc6544ed6858c4cd8d64b14d6af2367059091 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_ctr.py @@ -0,0 +1,109 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
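(Side note, a hedged sketch: because the layer now returns the statistic variables, a caller can zero them between evaluation passes. This assumes the scope/tensor API of this fluid version; the helper name is illustrative and not part of the patch.)

import numpy as np
import paddle.fluid as fluid

def reset_auc_states(auc_states, place=None):
    # auc_states is the list returned by fluid.layers.auc:
    # [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]
    place = place or fluid.CPUPlace()  # the states are created with force_cpu=True
    for state in auc_states:
        tensor = fluid.global_scope().find_var(state.name).get_tensor()
        tensor.set(np.zeros(state.shape, dtype="int64"), place)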
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid + +import dist_ctr_reader +from test_dist_base import TestDistRunnerBase, runtime_main + +IS_SPARSE = True + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistCTR2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta() + """ network definition """ + dnn_data = fluid.layers.data( + name="dnn_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + lr_data = fluid.layers.data( + name="lr_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + label = fluid.layers.data( + name="click", + shape=[-1, 1], + dtype="int64", + lod_level=0, + append_batch_size=False) + + # build dnn model + dnn_layer_dims = [128, 64, 32, 1] + dnn_embedding = fluid.layers.embedding( + is_distributed=False, + input=dnn_data, + size=[dnn_input_dim, dnn_layer_dims[0]], + param_attr=fluid.ParamAttr( + name="deep_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=IS_SPARSE) + dnn_pool = fluid.layers.sequence_pool( + input=dnn_embedding, pool_type="sum") + dnn_out = dnn_pool + for i, dim in enumerate(dnn_layer_dims[1:]): + fc = fluid.layers.fc( + input=dnn_out, + size=dim, + act="relu", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01)), + name='dnn-fc-%d' % i) + dnn_out = fc + + # build lr model + lr_embbding = fluid.layers.embedding( + is_distributed=False, + input=lr_data, + size=[lr_input_dim, 1], + param_attr=fluid.ParamAttr( + name="wide_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=IS_SPARSE) + lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") + + merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) + + predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') + acc = fluid.layers.accuracy(input=predict, label=label) + auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, + label=label) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + inference_program = paddle.fluid.default_main_program().clone() + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001) + sgd_optimizer.minimize(avg_cost) + + dataset = dist_ctr_reader.Dataset() + train_reader = paddle.batch(dataset.train(), batch_size=batch_size) + test_reader = paddle.batch(dataset.test(), batch_size=batch_size) + + return inference_program, avg_cost, train_reader, test_reader, None, predict + + +if __name__ == "__main__": + runtime_main(TestDistCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_ctr_reader.py b/python/paddle/fluid/tests/unittests/dist_ctr_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..95e39d891f7e6a3dcb57540bd96fe70027443cda --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_ctr_reader.py @@ -0,0 +1,172 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import paddle +import tarfile + +logging.basicConfig() +logger = logging.getLogger("paddle") +logger.setLevel(logging.INFO) + +DATA_URL = "http://paddle-ctr-data.cdn.bcebos.com/avazu_ctr_data.tgz" +DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e" +""" +avazu_ctr_data/train.txt +avazu_ctr_data/infer.txt +avazu_ctr_data/test.txt +avazu_ctr_data/data.meta.txt +""" + + +def read_data(file_name): + path = paddle.dataset.common.download(DATA_URL, "avazu_ctr_data", DATA_MD5) + tar = tarfile.open(path, "r:gz") + tar_info = None + for member in tar.getmembers(): + if member.name.endswith(file_name): + tar_info = member + f = tar.extractfile(tar_info) + ret_lines = [_.decode('utf-8') for _ in f.readlines()] + return ret_lines + + +class TaskMode: + TRAIN_MODE = 0 + TEST_MODE = 1 + INFER_MODE = 2 + + def __init__(self, mode): + self.mode = mode + + def is_train(self): + return self.mode == self.TRAIN_MODE + + def is_test(self): + return self.mode == self.TEST_MODE + + def is_infer(self): + return self.mode == self.INFER_MODE + + @staticmethod + def create_train(): + return TaskMode(TaskMode.TRAIN_MODE) + + @staticmethod + def create_test(): + return TaskMode(TaskMode.TEST_MODE) + + @staticmethod + def create_infer(): + return TaskMode(TaskMode.INFER_MODE) + + +class ModelType: + CLASSIFICATION = 0 + REGRESSION = 1 + + def __init__(self, mode): + self.mode = mode + + def is_classification(self): + return self.mode == self.CLASSIFICATION + + def is_regression(self): + return self.mode == self.REGRESSION + + @staticmethod + def create_classification(): + return ModelType(ModelType.CLASSIFICATION) + + @staticmethod + def create_regression(): + return ModelType(ModelType.REGRESSION) + + +def load_dnn_input_record(sent): + return list(map(int, sent.split())) + + +def load_lr_input_record(sent): + res = [] + for _ in [x.split(':') for x in sent.split()]: + res.append(int(_[0])) + return res + + +feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2} + + +class Dataset(object): + def train(self): + ''' + Load trainset. + ''' + file_name = "train.txt" + logger.info("load trainset from %s" % file_name) + mode = TaskMode.create_train() + return self._parse_creator(file_name, mode) + + def test(self): + ''' + Load testset. + ''' + file_name = "test.txt" + logger.info("load testset from %s" % file_name) + mode = TaskMode.create_test() + return self._parse_creator(file_name, mode) + + def infer(self): + ''' + Load infer set. + ''' + file_name = "infer.txt" + logger.info("load inferset from %s" % file_name) + mode = TaskMode.create_infer() + return self._parse_creator(file_name, mode) + + def _parse_creator(self, file_name, mode): + ''' + Parse dataset. 
+ ''' + + def _parse(): + data = read_data(file_name) + for line_id, line in enumerate(data): + fs = line.strip().split('\t') + dnn_input = load_dnn_input_record(fs[0]) + lr_input = load_lr_input_record(fs[1]) + if not mode.is_infer(): + click = int(fs[2]) + yield [dnn_input, lr_input, click] + else: + yield [dnn_input, lr_input] + + return _parse + + +def load_data_meta(): + ''' + load data meta info from path, return (dnn_input_dim, lr_input_dim) + ''' + lines = read_data('data.meta.txt') + err_info = "wrong meta format" + assert len(lines) == 2, err_info + assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[ + 1], err_info + res = map(int, [_.split(':')[1] for _ in lines]) + res = list(res) + logger.info('dnn input dim: %d' % res[0]) + logger.info('lr input dim: %d' % res[1]) + return res diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 85a96c0b53f6bc08687965048d6251265055a6fe..877d21ae882ab4efb49beb6a846ab71a22c2aab7 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -47,7 +47,7 @@ def cnn_model(data): pool_stride=2, act="relu", param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.3))) + value=0.01))) conv_pool_2 = fluid.nets.simple_img_conv_pool( input=conv_pool_1, filter_size=5, @@ -56,7 +56,7 @@ def cnn_model(data): pool_stride=2, act="relu", param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.2))) + value=0.01))) SIZE = 10 input_shape = conv_pool_2.shape @@ -68,7 +68,7 @@ def cnn_model(data): size=SIZE, act="softmax", param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Constant(value=0.1))) + initializer=fluid.initializer.Constant(value=0.01))) return predict diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index a4ffe7d40c40501ebd43fec0b664159227ea34bd..5da370570680e9f10a22ad882e3346e6381dfe63 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -247,7 +247,7 @@ class DistSeResneXt2x2(TestDistRunnerBase): # Reader train_reader = paddle.batch( - paddle.dataset.flowers.train(), batch_size=batch_size) + paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size) test_reader = paddle.batch( paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size) diff --git a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py new file mode 100644 index 0000000000000000000000000000000000000000..6456d1b53a129db04ace7ff4413a3d76e922ccde --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py @@ -0,0 +1,238 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import numpy as np +import argparse +import time +import math +import random + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main + +DTYPE = "int64" +DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/simnet.train.1000' +DATA_MD5 = '24e49366eb0611c552667989de2f57d5' + +# For Net +base_lr = 0.2 +emb_lr = base_lr * 3 +dict_dim = 1500 +emb_dim = 128 +hid_dim = 128 +margin = 0.1 +sample_rate = 1 + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +def get_acc(cos_q_nt, cos_q_pt, batch_size): + cond = fluid.layers.less_than(cos_q_nt, cos_q_pt) + cond = fluid.layers.cast(cond, dtype='float64') + cond_3 = fluid.layers.reduce_sum(cond) + acc = fluid.layers.elementwise_div( + cond_3, + fluid.layers.fill_constant( + shape=[1], value=batch_size * 1.0, dtype='float64'), + name="simnet_acc") + return acc + + +def get_loss(cos_q_pt, cos_q_nt): + loss_op1 = fluid.layers.elementwise_sub( + fluid.layers.fill_constant_batch_size_like( + input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32'), + cos_q_pt) + loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt) + loss_op3 = fluid.layers.elementwise_max( + fluid.layers.fill_constant_batch_size_like( + input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'), + loss_op2) + avg_cost = fluid.layers.mean(loss_op3) + return avg_cost + + +def get_optimizer(): + # SGD optimizer + optimizer = fluid.optimizer.SGD(learning_rate=base_lr) + return optimizer + + +def train_network(batch_size, is_distributed=False, is_sparse=False): + # query + q = fluid.layers.data( + name="query_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + q_emb = fluid.layers.embedding( + input=q, + is_distributed=is_distributed, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') + q_ss = fluid.layers.softsign(q_sum) + ## fc layer after conv + q_fc = fluid.layers.fc( + input=q_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__q_fc__", + learning_rate=base_lr)) + # label data + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + # pt + pt = fluid.layers.data( + name="pos_title_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + pt_emb = fluid.layers.embedding( + input=pt, + is_distributed=is_distributed, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum') + pt_ss = fluid.layers.softsign(pt_sum) + ## fc layer + pt_fc = fluid.layers.fc( + input=pt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + # nt + nt = fluid.layers.data( + name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) + ## embedding + nt_emb = fluid.layers.embedding( + input=nt, + is_distributed=is_distributed, + 
size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr), + is_sparse=is_sparse) + ## vsum + nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') + nt_ss = fluid.layers.softsign(nt_sum) + ## fc layer + nt_fc = fluid.layers.fc( + input=nt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc) + cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc) + # loss + avg_cost = get_loss(cos_q_pt, cos_q_nt) + # acc + acc = get_acc(cos_q_nt, cos_q_pt, batch_size) + return [avg_cost, acc, cos_q_pt] + + +def combination(x, y): + res = [[[xi, yi] for yi in y] for xi in x] + return res[0] + + +def get_one_data(file_list): + for file in file_list: + contents = [] + with open(file, "r") as fin: + for i in fin: + contents.append(i.strip()) + for index, q in enumerate(contents): + try: + one_data = [[int(j) for j in i.split(" ")] + for i in q.split(";")[:-1]] + if one_data[1][0] + one_data[1][1] != len(one_data) - 3: + q = fin.readline() + continue + tmp = combination(one_data[3:3 + one_data[1][0]], + one_data[3 + one_data[1][0]:]) + except Exception as e: + continue + + for each in tmp: + yield [one_data[2], 0, each[0], each[1]] + + +def get_batch_reader(file_list, batch_size): + def batch_reader(): + res = [] + for i in get_one_data(file_list): + if random.random() <= sample_rate: + res.append(i) + if len(res) >= batch_size: + yield res + res = [] + + return batch_reader + + +def get_train_reader(batch_size): + # The training data set. + train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet", + "train") + train_reader = get_batch_reader([train_file], batch_size) + train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"] + return train_reader, train_feed + + +class TestDistSimnetBow2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Train program + avg_cost, acc, predict = \ + train_network(batch_size, bool(int(os.environ["IS_DISTRIBUTED"])), bool(int(os.environ["IS_SPARSE"]))) + + inference_program = fluid.default_main_program().clone() + + # Optimization + opt = get_optimizer() + opt.minimize(avg_cost) + + # Reader + train_reader, _ = get_train_reader(batch_size) + return inference_program, avg_cost, train_reader, train_reader, acc, predict + + +if __name__ == "__main__": + paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train") + runtime_main(TestDistSimnetBow2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_text_classification.py b/python/paddle/fluid/tests/unittests/dist_text_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..095a474fd3ac056c678f9051ed80ef363ae968c9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_text_classification.py @@ -0,0 +1,231 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +import six +import tarfile +import string +import re +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main + +DTYPE = "float32" +VOCAB_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/imdb.vocab' +VOCAB_MD5 = '23c86a0533c0151b6f12fa52b106dcc2' +DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/text_classification.tar.gz' +DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a' + + +# Load dictionary. +def load_vocab(filename): + vocab = {} + if six.PY2: + with open(filename, 'r') as f: + for idx, line in enumerate(f): + vocab[line.strip()] = idx + else: + with open(filename, 'r', encoding="utf-8") as f: + for idx, line in enumerate(f): + vocab[line.strip()] = idx + return vocab + + +def get_worddict(dict_path): + word_dict = load_vocab(dict_path) + word_dict[""] = len(word_dict) + dict_dim = len(word_dict) + return word_dict, dict_dim + + +def conv_net(input, + dict_dim, + emb_dim=128, + window_size=3, + num_filters=128, + fc0_dim=96, + class_dim=2): + emb = fluid.layers.embedding( + input=input, + size=[dict_dim, emb_dim], + is_sparse=False, + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.01))) + + conv_3 = fluid.nets.sequence_conv_pool( + input=emb, + num_filters=num_filters, + filter_size=window_size, + act="tanh", + pool_type="max", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + fc_0 = fluid.layers.fc( + input=[conv_3], + size=fc0_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + prediction = fluid.layers.fc( + input=[fc_0], + size=class_dim, + act="softmax", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01))) + + return prediction + + +def inference_network(dict_dim): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + out = conv_net(data, dict_dim) + return out + + +def get_reader(word_dict, batch_size): + # The training data set. + train_reader = paddle.batch(train(word_dict), batch_size=batch_size) + + # The testing data set. 
+ test_reader = paddle.batch(test(word_dict), batch_size=batch_size) + + return train_reader, test_reader + + +def get_optimizer(learning_rate): + optimizer = fluid.optimizer.SGD(learning_rate=learning_rate) + return optimizer + + +class TestDistTextClassification2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + vocab = os.path.join(paddle.dataset.common.DATA_HOME, + "text_classification", "imdb.vocab") + word_dict, dict_dim = get_worddict(vocab) + + # Input data + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = conv_net(data, dict_dim) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=predict, label=label) + inference_program = fluid.default_main_program().clone() + + # Optimization + opt = get_optimizer(learning_rate=0.001) + opt.minimize(avg_cost) + + # Reader + train_reader, test_reader = get_reader(word_dict, batch_size) + + return inference_program, avg_cost, train_reader, test_reader, acc, predict + + +def tokenize(pattern): + """ + Read files that match the given pattern. Tokenize and yield each file. + """ + + with tarfile.open( + paddle.dataset.common.download(DATA_URL, 'text_classification', + DATA_MD5)) as tarf: + # Note that we should use tarfile.next(), which does + # sequential access of member files, other than + # tarfile.extractfile, which does random access and might + # destroy hard disks. + tf = tarf.next() + while tf != None: + if bool(pattern.match(tf.name)): + # newline and punctuations removal and ad-hoc tokenization. + yield tarf.extractfile(tf).read().rstrip(six.b( + "\n\r")).translate( + None, six.b(string.punctuation)).lower().split() + tf = tarf.next() + + +def reader_creator(pos_pattern, neg_pattern, word_idx): + UNK = word_idx[''] + INS = [] + + def load(pattern, out, label): + for doc in tokenize(pattern): + out.append(([word_idx.get(w, UNK) for w in doc], label)) + + load(pos_pattern, INS, 0) + load(neg_pattern, INS, 1) + + def reader(): + for doc, label in INS: + yield doc, label + + return reader + + +def train(word_idx): + """ + IMDB training set creator. + + It returns a reader creator, each sample in the reader is an zero-based ID + sequence and label in [0, 1]. + + :param word_idx: word dictionary + :type word_idx: dict + :return: Training reader creator + :rtype: callable + """ + return reader_creator( + re.compile("train/pos/.*\.txt$"), + re.compile("train/neg/.*\.txt$"), word_idx) + + +def test(word_idx): + """ + IMDB test set creator. + + It returns a reader creator, each sample in the reader is an zero-based ID + sequence and label in [0, 1]. 
+ + :param word_idx: word dictionary + :type word_idx: dict + :return: Test reader creator + :rtype: callable + """ + return reader_creator( + re.compile("test/pos/.*\.txt$"), + re.compile("test/neg/.*\.txt$"), word_idx) + + +if __name__ == "__main__": + paddle.dataset.common.download(VOCAB_URL, 'text_classification', VOCAB_MD5) + paddle.dataset.common.download(DATA_URL, 'text_classification', DATA_MD5) + runtime_main(TestDistTextClassification2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index f53f7f3b354e60619b17d601ff3f55d2b8b059db..a2cc57425841100a2b61279d1b447b88ed4b9a54 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -1699,10 +1699,9 @@ class DistTransformer2x2(TestDistRunnerBase): exe.run(startup_prog) exe.run(pserver_prog) - def run_trainer(self, use_cuda, args): - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - TrainTaskConfig.use_gpu = use_cuda - sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model( + def run_trainer(self, args): + TrainTaskConfig.use_gpu = args.use_cuda + sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( args.is_dist, not args.sync_mode) if args.is_dist: @@ -1718,6 +1717,11 @@ class DistTransformer2x2(TestDistRunnerBase): TrainTaskConfig.batch_size = 20 trainer_prog = fluid.default_main_program() + if args.use_cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + startup_exe = fluid.Executor(place) TrainTaskConfig.local = not args.is_dist diff --git a/python/paddle/fluid/tests/unittests/dist_word2vec.py b/python/paddle/fluid/tests/unittests/dist_word2vec.py index f3e740fc7027a4a562b836c3113b87d55062c185..835306edd0f17490dd10110db40f42dce30b25bb 100644 --- a/python/paddle/fluid/tests/unittests/dist_word2vec.py +++ b/python/paddle/fluid/tests/unittests/dist_word2vec.py @@ -122,4 +122,7 @@ class TestDistWord2vec2x2(TestDistRunnerBase): if __name__ == "__main__": + import os + os.environ['CPU_NUM'] = '1' + os.environ['USE_CUDA'] = "FALSE" runtime_main(TestDistWord2vec2x2) diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 1de4a9d016a177944253d12094722d3a05614be2..810e8a1a8547a92de923877695178e780981edeb 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -36,7 +36,11 @@ class TestAucOp(OpTest): "StatPos": stat_pos, "StatNeg": stat_neg } - self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} + self.attrs = { + 'curve': 'ROC', + 'num_thresholds': num_thresholds, + "slide_steps": 1 + } python_auc = metrics.Auc(name="auc", curve='ROC', @@ -45,7 +49,6 @@ class TestAucOp(OpTest): self.outputs = { 'AUC': np.array(python_auc.eval()), - 'BatchAUC': np.array(python_auc.eval()), 'StatPosOut': np.array(python_auc._stat_pos), 'StatNegOut': np.array(python_auc._stat_neg) } diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 37cad73019c529f64868b6ad3c6e2fffe59cc0d8..6c52497e7fa682987f4bf25015363051417efb6a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -18,23 +18,27 @@ import time import unittest import os import sys -import six import signal import subprocess +import six import argparse +import paddle.fluid as fluid 
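(For reference, the AUC the kernel derives from the bucket histograms is the usual trapezoid rule over the ROC curve; the plain-Python sketch below mirrors calcAuc in auc_op.h and the metrics.Auc reference that test_auc_op compares against. Illustrative only.)

def auc_from_buckets(stat_pos, stat_neg):
    # stat_pos[i] / stat_neg[i] count positive / negative samples whose
    # prediction fell into bucket i (0 .. num_thresholds).
    tot_pos = tot_neg = 0.0
    area = 0.0
    # Walk from the highest-score bucket down, accumulating TP/FP counts and
    # summing trapezoid areas under the (FP, TP) curve.
    for idx in range(len(stat_pos) - 1, -1, -1):
        prev_pos, prev_neg = tot_pos, tot_neg
        tot_pos += stat_pos[idx]
        tot_neg += stat_neg[idx]
        area += abs(tot_neg - prev_neg) * (tot_pos + prev_pos) / 2.0
    return 0.0 if tot_pos * tot_neg == 0 else area / (tot_pos * tot_neg)

# Perfectly separated data gives AUC 1.0:
# auc_from_buckets([0, 10], [10, 0]) == 1.0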
+ +RUN_STEP = 10 + class TestDistRunnerBase(object): def get_model(self, batch_size=2): raise NotImplementedError( "get_model should be implemented by child classes.") - def get_transpiler(self, trainer_id, main_program, pserver_endpoints, - trainers, sync_mode): + @staticmethod + def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers, + sync_mode): # NOTE: import fluid until runtime, or else forking processes will cause error. - import paddle - import paddle.fluid as fluid - t = fluid.DistributeTranspiler() + config = fluid.DistributeTranspilerConfig() + t = fluid.DistributeTranspiler(config=config) t.transpile( trainer_id=trainer_id, program=main_program, @@ -44,9 +48,9 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): - import paddle - import paddle.fluid as fluid + self.get_model(batch_size=2) + if args.mem_opt: fluid.memory_optimize(fluid.default_main_program()) t = self.get_transpiler(args.trainer_id, @@ -61,12 +65,10 @@ class TestDistRunnerBase(object): exe.run(startup_prog) exe.run(pserver_prog) - def run_trainer(self, use_cuda, args): - import paddle - import paddle.fluid as fluid - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + def run_trainer(self, args): test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=2) + if args.mem_opt: fluid.memory_optimize(fluid.default_main_program()) if args.is_dist: @@ -74,16 +76,23 @@ class TestDistRunnerBase(object): fluid.default_main_program(), args.endpoints, args.trainers, args.sync_mode) + trainer_prog = t.get_trainer_program() else: trainer_prog = fluid.default_main_program() + if args.use_cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + startup_exe = fluid.Executor(place) startup_exe.run(fluid.default_startup_program()) strategy = fluid.ExecutionStrategy() strategy.num_threads = 1 strategy.allow_op_delay = False + build_stra = fluid.BuildStrategy() if args.use_reduce: @@ -92,7 +101,7 @@ class TestDistRunnerBase(object): build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce exe = fluid.ParallelExecutor( - use_cuda, + args.use_cuda, loss_name=avg_cost.name, exec_strategy=strategy, build_strategy=build_stra) @@ -103,27 +112,26 @@ class TestDistRunnerBase(object): ] feeder = fluid.DataFeeder(feed_var_list, place) - reader_generator = test_reader() + reader_generator = train_reader() - data = next(reader_generator) - first_loss, = exe.run(fetch_list=[avg_cost.name], - feed=feeder.feed(data)) - print(first_loss) + def get_data(): + origin_batch = next(reader_generator) + if args.is_dist and args.use_reader_alloc: + new_batch = [] + for offset, item in enumerate(origin_batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return origin_batch - for i in six.moves.xrange(5): - data = next(reader_generator) - loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) - - data = next(reader_generator) - last_loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) - print(last_loss) + for _ in six.moves.xrange(RUN_STEP): + loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + print(loss) def runtime_main(test_class): - import paddle - import paddle.fluid as fluid - import paddle.fluid.core as core - parser = argparse.ArgumentParser(description='Run dist test.') parser.add_argument( '--role', type=str, required=True, choices=['pserver', 'trainer']) @@ -135,7 +143,10 @@ def runtime_main(test_class): 
'--current_endpoint', type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--mem_opt', action='store_true') + parser.add_argument('--use_cuda', action='store_true') parser.add_argument('--use_reduce', action='store_true') + parser.add_argument( + '--use_reader_alloc', action='store_true', required=False, default=True) args = parser.parse_args() @@ -143,8 +154,7 @@ def runtime_main(test_class): if args.role == "pserver" and args.is_dist: model.run_pserver(args) else: - use_cuda = True if core.is_compiled_with_cuda() else False - model.run_trainer(use_cuda, args) + model.run_trainer(args) import paddle.compat as cpt @@ -156,6 +166,17 @@ class TestDistBase(unittest.TestCase): def _setup_config(self): raise NotImplementedError("tests should have _setup_config implemented") + def _after_setup_config(self): + if self._enforce_place == "CPU": + self.__use_cuda = False + elif self._enforce_place == "GPU": + self.__use_cuda = True + else: + if fluid.core.is_compiled_with_cuda(): + self.__use_cuda = True + else: + self.__use_cuda = False + def setUp(self): self._trainers = 2 self._pservers = 2 @@ -163,24 +184,27 @@ class TestDistBase(unittest.TestCase): self._find_free_port(), self._find_free_port()) self._python_interp = "python" self._sync_mode = True + self._enforce_place = None self._mem_opt = False self._use_reduce = False + self._use_reader_alloc = True self._setup_config() + self._after_setup_config() def _find_free_port(self): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(('', 0)) return s.getsockname()[1] - def start_pserver(self, model_file, check_error_log): + def start_pserver(self, model_file, check_error_log, required_envs): ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist" ps0_cmd = ps_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, ps0_ep, - self._trainers) + (self._python_interp, model_file, self._ps_endpoints, ps0_ep, + self._trainers) ps1_cmd = ps_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, ps1_ep, - self._trainers) + (self._python_interp, model_file, self._ps_endpoints, ps1_ep, + self._trainers) if self._sync_mode: ps0_cmd += " --sync_mode" @@ -189,23 +213,23 @@ class TestDistBase(unittest.TestCase): ps0_cmd += " --mem_opt" ps1_cmd += " --mem_opt" - ps0_pipe = subprocess.PIPE - ps1_pipe = subprocess.PIPE - if check_error_log: - print(ps0_cmd) - print(ps1_cmd) - ps0_pipe = open("/tmp/ps0_err.log", "wb") - ps1_pipe = open("/tmp/ps1_err.log", "wb") + print(ps0_cmd) + print(ps1_cmd) + ps0_pipe = open("/tmp/ps0_err.log", "wb") + ps1_pipe = open("/tmp/ps1_err.log", "wb") ps0_proc = subprocess.Popen( - ps0_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps0_pipe) + ps0_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=ps0_pipe, + env=required_envs) ps1_proc = subprocess.Popen( - ps1_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps1_pipe) + ps1_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=ps1_pipe, + env=required_envs) - if not check_error_log: - return ps0_proc, ps1_proc, None, None - else: - return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe + return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe def _wait_ps_ready(self, pid): retry_times = 50 @@ -222,59 +246,59 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def check_with_place(self, model_file, delta=1e-3, check_error_log=False): - # 
TODO(typhoonzero): should auto adapt GPU count on the machine. - required_envs = { - "PATH": os.getenv("PATH", ""), - "PYTHONPATH": os.getenv("PYTHONPATH", ""), - "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), - "FLAGS_fraction_of_gpu_memory_to_use": "0.15", - "FLAGS_cudnn_deterministic": "1", - "CPU_NUM": "1" - } + def _run_local(self, model, envs, check_error_log): - if check_error_log: - required_envs["GLOG_v"] = "7" - required_envs["GLOG_logtostderr"] = "1" + cmd = "%s %s --role trainer" % (self._python_interp, model) - # Run local to get a base line - env_local = {"CUDA_VISIBLE_DEVICES": "0"} - env_local.update(required_envs) - local_cmd = "%s %s --role trainer" % (self._python_interp, model_file) - if not check_error_log: - local_proc = subprocess.Popen( - local_cmd.split(" "), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env_local) + if self.__use_cuda: + cmd += " --use_cuda" + env_local = {"CUDA_VISIBLE_DEVICES": "0"} else: + env_local = {'CPU_NUM': '1'} + + envs.update(env_local) + + if check_error_log: err_log = open("/tmp/trainer.err.log", "wb") local_proc = subprocess.Popen( - local_cmd.split(" "), + cmd.split(" "), stdout=subprocess.PIPE, stderr=err_log, - env=env_local) + env=envs) + else: + local_proc = subprocess.Popen( + cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=envs) - local_proc.wait() - out, err = local_proc.communicate() - local_ret = cpt.to_text(out) - sys.stderr.write('local_loss: %s\n' % local_ret) - sys.stderr.write('local_stderr: %s\n' % err) + local_out, local_err = local_proc.communicate() + local_ret = cpt.to_text(local_out) + + if check_error_log: + err_log.close() + sys.stderr.write('local_stdout: %s\n' % local_ret) + sys.stderr.write('local_stderr: %s\n' % local_err) + + local_losses = local_ret.split("\n") + return local_losses + + def _run_cluster(self, model, envs, check_error_log): # Run dist train to compare with local results - ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model_file, - check_error_log) + ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model, + check_error_log, envs) self._wait_ps_ready(ps0.pid) self._wait_ps_ready(ps1.pid) - ps0_ep, ps1_ep = self._ps_endpoints.split(",") + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist" tr0_cmd = tr_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, - 0, ps0_ep, self._trainers) + (self._python_interp, model, self._ps_endpoints, + 0, ps0_ep, self._trainers) tr1_cmd = tr_cmd % \ - (self._python_interp, model_file, self._ps_endpoints, - 1, ps1_ep, self._trainers) + (self._python_interp, model, self._ps_endpoints, + 1, ps1_ep, self._trainers) if self._sync_mode: tr0_cmd += " --sync_mode" @@ -285,20 +309,25 @@ class TestDistBase(unittest.TestCase): if self._use_reduce: tr0_cmd += " --use_reduce" tr1_cmd += " --use_reduce" + if self._use_reader_alloc: + tr0_cmd += " --use_reader_alloc" + tr1_cmd += " --use_reader_alloc" + if self.__use_cuda: + tr0_cmd += " --use_cuda" + tr1_cmd += " --use_cuda" + env0 = {"CUDA_VISIBLE_DEVICES": "0"} + env1 = {"CUDA_VISIBLE_DEVICES": "1"} + else: + env0 = {'CPU_NUM': '1'} + env1 = {'CPU_NUM': '1'} - env0 = {"CUDA_VISIBLE_DEVICES": "0"} - env1 = {"CUDA_VISIBLE_DEVICES": "1"} - env0.update(required_envs) - env1.update(required_envs) - FNULL = open(os.devnull, 'w') + env0.update(envs) + env1.update(envs) - tr0_pipe = subprocess.PIPE - tr1_pipe = subprocess.PIPE - if check_error_log: - print("tr0_cmd:", tr0_cmd) - print("tr1_cmd:", tr1_cmd) - 
tr0_pipe = open("/tmp/tr0_err.log", "wb") - tr1_pipe = open("/tmp/tr1_err.log", "wb") + print("tr0_cmd:{}, env0: {}".format(tr0_cmd, env0)) + print("tr1_cmd:{}, env1: {}".format(tr1_cmd, env1)) + tr0_pipe = open("/tmp/tr0_err.log", "wb") + tr1_pipe = open("/tmp/tr1_err.log", "wb") tr0_proc = subprocess.Popen( tr0_cmd.strip().split(" "), @@ -311,35 +340,65 @@ class TestDistBase(unittest.TestCase): stderr=tr1_pipe, env=env1) - tr0_proc.wait() - tr1_proc.wait() - out, err = tr0_proc.communicate() - sys.stderr.write('dist_stderr: %s\n' % err) - loss_data0 = cpt.to_text(out) - sys.stderr.write('dist_loss: %s\n' % loss_data0) - lines = loss_data0.split("\n") - dist_first_loss = eval(lines[0].replace(" ", ","))[0] - dist_last_loss = eval(lines[1].replace(" ", ","))[0] - - local_lines = local_ret.split("\n") - local_first_loss = eval(local_lines[0])[0] - local_last_loss = eval(local_lines[1])[0] + tr0_out, tr0_err = tr0_proc.communicate() + tr0_loss_text = cpt.to_text(tr0_out) + tr1_out, tr1_err = tr1_proc.communicate() + tr1_loss_text = cpt.to_text(tr1_out) # close trainer file - if check_error_log: - tr0_pipe.close() - tr1_pipe.close() + tr0_pipe.close() + tr1_pipe.close() - ps0_pipe.close() - ps1_pipe.close() + ps0_pipe.close() + ps1_pipe.close() # FIXME: use terminate() instead of sigkill. os.kill(ps0.pid, signal.SIGKILL) os.kill(ps1.pid, signal.SIGKILL) ps0.terminate() ps1.terminate() - ps0.wait() - ps1.wait() - FNULL.close() - self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta) - self.assertAlmostEqual(local_last_loss, dist_last_loss, delta=delta) + # print log + sys.stderr.write('trainer 0 stdout:\n %s\n' % tr0_loss_text) + sys.stderr.write('trainer 0 stderr:\n %s\n' % tr0_err) + sys.stderr.write('trainer 1 stdout: %s\n' % tr1_loss_text) + sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + + tr0_losses = tr0_loss_text.split("\n") + tr1_losses = tr1_loss_text.split("\n") + + return tr0_losses, tr1_losses + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + # TODO(typhoonzero): should auto adapt GPU count on the machine. + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1", + "http_proxy": "" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + + local_losses\ + = self._run_local(model_file, required_envs, + check_error_log) + tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs, + check_error_log) + + for step_id in range(RUN_STEP): + local_loss = eval(local_losses[step_id])[0] + tr0_loss = eval(tr0_losses[step_id])[0] + tr1_loss = eval(tr1_losses[step_id])[0] + dist_loss = (tr0_loss + tr1_loss) / 2 + print(str(local_loss) + ":" + str(dist_loss)) + self.assertAlmostEqual(local_loss, dist_loss, delta=delta) diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py new file mode 100644 index 0000000000000000000000000000000000000000..75a741f13ec99052088666fde624f2093f14461a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -0,0 +1,32 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import os
+import unittest
+from test_dist_base import TestDistBase
+
+
+class TestDistCTR2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._enforce_place = "CPU"
+
+
+    def test_dist_ctr(self):
+        self.check_with_place("dist_ctr.py", delta=1e-7)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
index 09b1c546e49bd02bf336f31885bf4c7339cc5a2c..f65dd7e2a28c4ace3988c0cc1267ebe981fbd9cb 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -23,7 +23,7 @@ class TestDistMnist2x2(TestDistBase):
         self._use_reduce = False

     def test_dist_train(self):
-        self.check_with_place("dist_mnist.py", delta=1e-7)
+        self.check_with_place("dist_mnist.py", delta=1e-5)


 class TestDistMnist2x2WithMemopt(TestDistBase):
@@ -32,7 +32,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
         self._mem_opt = True

     def test_dist_train(self):
-        self.check_with_place("dist_mnist.py", delta=1e-7)
+        self.check_with_place("dist_mnist.py", delta=1e-5)


 class TestDistMnistAsync(TestDistBase):
diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
index c2b089694ea2f329e67ad6c50def26caa454720e..d2d927aca8428acd88a6a73c05d70e93439f861c 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
@@ -20,9 +20,10 @@ from test_dist_base import TestDistBase
 class TestDistSeResneXt2x2(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
+        self._use_reader_alloc = False

     def test_dist_train(self):
-        self.check_with_place("dist_se_resnext.py", delta=1e-7)
+        self.check_with_place("dist_se_resnext.py", delta=100)


 # TODO(typhoonzero): fix this test
@@ -38,6 +39,7 @@ class TestDistSeResneXt2x2(TestDistBase):
 class TestDistSeResneXt2x2Async(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
+        self._use_reader_alloc = False

     def test_dist_train(self):
         self.check_with_place("dist_se_resnext.py", delta=100)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e971f29db42a7c1a2394505a8ece3d2fd6b347e9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import os
+import unittest
+
+from test_dist_base import TestDistBase
+
+
+class TestDistSimnetBowDense2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._enforce_place = "CPU"
+
+    def test_simnet_bow(self):
+        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
+        self.check_with_place(
+            "dist_simnet_bow.py",
+            delta=1e-5,
+            check_error_log=False,
+            need_envs=need_envs)
+
+
+class TestDistSimnetBow2x2DenseAsync(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._enforce_place = "CPU"
+
+    def test_simnet_bow(self):
+        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
+        self.check_with_place(
+            "dist_simnet_bow.py",
+            delta=100,
+            check_error_log=False,
+            need_envs=need_envs)
+
+
+class TestDistSimnetBowSparse2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._enforce_place = "CPU"
+
+    def test_simnet_bow(self):
+        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
+        self.check_with_place(
+            "dist_simnet_bow.py",
+            delta=1e-5,
+            check_error_log=False,
+            need_envs=need_envs)
+
+
+class TestDistSimnetBow2x2SparseAsync(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._enforce_place = "CPU"
+
+    def test_simnet_bow(self):
+        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
+        self.check_with_place(
+            "dist_simnet_bow.py",
+            delta=100,
+            check_error_log=False,
+            need_envs=need_envs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_text_classification.py b/python/paddle/fluid/tests/unittests/test_dist_text_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1680359e2b84807084b06eab0534b41ecd6133
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_text_classification.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import os
+import unittest
+from test_dist_base import TestDistBase
+
+
+class TestDistTextClassification2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._enforce_place = "CPU"
+
+    def test_text_classification(self):
+        self.check_with_place("dist_text_classification.py", delta=1e-6)
+
+
+class TestDistTextClassification2x2Async(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._enforce_place = "CPU"
+
+    def test_text_classification(self):
+        self.check_with_place("dist_text_classification.py", delta=100)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
index 33b39b262b95b0013e3696c3f15a288a2e801ce1..b26cbdbea12962a3a41036c774de5dfb61999205 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_word2vec.py
@@ -39,7 +39,7 @@ class TestDistW2V2x2Async(TestDistBase):
         self._sync_mode = False

     def test_dist_train(self):
-        self.check_with_place("dist_word2vec.py", delta=1)
+        self.check_with_place("dist_word2vec.py", delta=100)


 if __name__ == "__main__":
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 9ba280293d7e085da243206ace88b774f0157be1..ecdbe27f4d90268d755a712e25289cfaf4715f29 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -39,8 +39,8 @@ import six
 from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
-    default_startup_program, Block, \
-    Parameter, grad_var_name
+    default_startup_program, Block, \
+    Parameter, grad_var_name
 from .details import *
 from functools import reduce
@@ -178,7 +178,7 @@ class DistributeTranspiler(object):
                 pserver_program)
             elif role == "TRAINER":
                 trainer_program = t.get_trainer_program()
-
+
             # for nccl2 mode
             config = fluid.DistributeTranspilerConfig()
             config.mode = "nccl2"
@@ -537,7 +537,7 @@ class DistributeTranspiler(object):
             })

         for varname, splited_var in six.iteritems(self.param_var_mapping):
-            #add concat ops to merge splited parameters received from parameter servers.
+            # add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
             # NOTE: if enable memory optimization, origin vars maybe removed.
@@ -737,19 +737,14 @@ in a single call.")
             table_opt_block = self._create_table_optimize_block(
                 pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             optimize_blocks.append(table_opt_block)
-            prefetch_var_name_to_block_id = self._create_prefetch_block(
+            lookup_table_var_name_to_block_id = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)
             checkpoint_block_id = self._create_checkpoint_save_block(
                 pserver_program, table_opt_block.idx)

             pserver_program._distributed_lookup_table = self.table_name
-
-        # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
-        # not be executed, so it's safe to use optimize_block to hold the place
-        if self.has_distributed_lookup_table:
-            assert len(prefetch_var_name_to_block_id) > 0
-        else:
-            assert len(prefetch_var_name_to_block_id) == 0
+            prefetch_var_name_to_block_id.extend(
+                lookup_table_var_name_to_block_id)

         attrs = {
             "optimize_blocks": optimize_blocks,
@@ -758,11 +753,14 @@ in a single call.")
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,
         }
-        if len(prefetch_var_name_to_block_id) > 0:
-            attrs['prefetch_var_name_to_block_id'] \
-                = prefetch_var_name_to_block_id
+
+        if self.has_distributed_lookup_table:
             attrs['checkpint_block_id'] = checkpoint_block_id
+        if len(prefetch_var_name_to_block_id) > 0:
+            attrs[
+                'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id
+
         # step5 append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
@@ -1016,7 +1014,7 @@ to transpile() call.")
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
-            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
                 self.param_var_mapping[p_name][int(p_bid)]

         # create mapping of endpoint -> split var to create pserver side program
@@ -1323,7 +1321,7 @@ to transpile() call.")
             if len(splited) == 1:
                 if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
-                        (orig_var.name, self.trainer_id)
+                        (orig_var.name, self.trainer_id)
                     program.global_block()._rename_var(varname, new_var_name)
                     var_mapping[varname] = \
                         [program.global_block().var(new_var_name)]
@@ -1346,10 +1344,10 @@ to transpile() call.")
                 new_var_name = ""
                 if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.block%d.trainer_%d" % \
-                        (varname, i, self.trainer_id)
+                        (varname, i, self.trainer_id)
                 else:
                     new_var_name = "%s.block%d" % \
-                        (varname, i)
+                        (varname, i)
                 var = program.global_block().create_var(
                     name=new_var_name,
                     persistable=False,
@@ -1490,9 +1488,8 @@ to transpile() call.")
             vars2merge = []
             for i in range(self.trainer_num):
                 per_trainer_name = "%s.trainer_%d" % \
-                    (merged_var_name, i)
+                    (merged_var_name, i)
                 vars2merge.append(pserver_block.vars[per_trainer_name])
-
             optimize_block.append_op(
                 type="sum",
                 inputs={"X": vars2merge},
@@ -1651,7 +1648,7 @@ to transpile() call.")
             # one op's output is another op's input, we say
             # the two operator is connected.
             if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \
-                    set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
+                    set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
                 return True
         return False
@@ -1668,7 +1665,7 @@ to transpile() call.")

     def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
-                "LearningRate" in op.input_names:
+                "LearningRate" in op.input_names:
             return True
         return False
@@ -1743,7 +1740,7 @@ to transpile() call.")
             # NOTE: we need to skip all optimize ops, since it is connected
             # with forward/backward ops and lr ops, we only need the lr ops.
             if op1 != op2 and self._is_op_connected(op1, op2) and \
-                    not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
+                    not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
                 ufind.union(op1, op2)
         # find all ops which is related with lr var
         for op1 in block.ops:
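Note: the refactored check_with_place in test_dist_base.py above compares the local run against the two trainers step by step, averaging the trainers' losses, rather than checking only the first and last loss. The plain-Python sketch below (not part of the patch) illustrates that comparison logic; it assumes each process prints one loss list per training step (e.g. "[0.6931]"), and the helper name compare_losses plus the sample data are hypothetical, not part of the PaddlePaddle test suite.

# Standalone sketch of the per-step comparison performed by check_with_place.
# compare_losses is a hypothetical helper; ast.literal_eval stands in for the
# eval() call used in the actual test code.
import ast


def compare_losses(local_lines, tr0_lines, tr1_lines, delta=1e-5):
    """Check that the averaged distributed loss tracks the local loss."""
    for step_id, (l, t0, t1) in enumerate(zip(local_lines, tr0_lines, tr1_lines)):
        local_loss = ast.literal_eval(l)[0]
        dist_loss = (ast.literal_eval(t0)[0] + ast.literal_eval(t1)[0]) / 2
        if abs(local_loss - dist_loss) > delta:
            print("step %d diverged: local=%f dist=%f" %
                  (step_id, local_loss, dist_loss))
            return False
    return True


if __name__ == "__main__":
    # Toy two-step example: the two trainers' per-step losses average back to
    # the local loss, so the check passes with a loose delta.
    local = ["[0.6931]", "[0.6905]"]
    tr0 = ["[0.6930]", "[0.6903]"]
    tr1 = ["[0.6932]", "[0.6907]"]
    assert compare_losses(local, tr0, tr1, delta=1e-3)

Averaging the two trainers' per-step losses before comparing them with the single-process run is what lets the sync-mode tests above keep a small delta (1e-5 or 1e-6), while the async variants fall back to delta=100 and effectively act as smoke tests.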