未验证 提交 de195518 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13621 from seiriosPlus/release/1.0.0

CherryPick BatchAUC/Distributed UT from develop
...@@ -269,7 +269,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None ...@@ -269,7 +269,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......
...@@ -27,8 +27,11 @@ class SelectedRowsTester : public ::testing::Test { ...@@ -27,8 +27,11 @@ class SelectedRowsTester : public ::testing::Test {
selected_rows_.reset(new SelectedRows(rows, height)); selected_rows_.reset(new SelectedRows(rows, height));
Tensor* value = selected_rows_->mutable_value(); Tensor* value = selected_rows_->mutable_value();
value->mutable_data<float>( auto* data = value->mutable_data<float>(
make_ddim({static_cast<int64_t>(rows.size()), row_numel}), place_); make_ddim({static_cast<int64_t>(rows.size()), row_numel}), place_);
for (int64_t i = 0; i < value->numel(); ++i) {
data[i] = static_cast<float>(i);
}
} }
protected: protected:
...@@ -60,6 +63,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ...@@ -60,6 +63,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->height(), dst_tensor.height()); ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims()); ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims());
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
auto* dst_data = dst_tensor.value().data<float>();
for (int64_t i = 0; i < dst_tensor.value().numel(); ++i) {
ASSERT_EQ(dst_data[i], static_cast<float>(i));
}
} }
TEST(SelectedRows, SparseTable) { TEST(SelectedRows, SparseTable) {
......
...@@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel {
"Out and Label should have same height."); "Out and Label should have same height.");
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1; int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps");
PADDLE_ENFORCE_GE(num_pred_buckets, 1, "num_thresholds must larger than 1");
PADDLE_ENFORCE_GE(slide_steps, 0, "slide_steps must be natural number");
ctx->SetOutputDim("AUC", {1}); ctx->SetOutputDim("AUC", {1});
ctx->SetOutputDim("BatchAUC", {1});
ctx->SetOutputDim("StatPosOut", {num_pred_buckets}); slide_steps = slide_steps == 0 ? 1 : slide_steps;
ctx->SetOutputDim("StatNegOut", {num_pred_buckets}); ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets});
} }
protected: protected:
...@@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data. " "A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]"); "shape: [batch_size, 1]");
// TODO(typhoonzero): support weight input // TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1"); AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0"); AddInput("StatNeg", "Statistic value when label = 0");
...@@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("AUC", AddOutput("AUC",
"A scalar representing the " "A scalar representing the "
"current area-under-the-curve."); "current area-under-the-curve.");
AddOutput("BatchAUC", "The AUC for current batch");
AddOutput("StatPosOut", "Statistic value when label = 1"); AddOutput("StatPosOut", "Statistic value when label = 1");
AddOutput("StatNegOut", "Statistic value when label = 0"); AddOutput("StatNegOut", "Statistic value when label = 0");
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.") AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC"); .SetDefault("ROC");
AddAttr<int>("num_thresholds", AddAttr<int>(
"The number of thresholds to use when discretizing the" "num_thresholds",
" roc curve.") "The number of thresholds to use when discretizing the roc curve.")
.SetDefault((2 << 12) - 1); .SetDefault((2 << 12) - 1);
AddAttr<int>("slide_steps", "Use slide steps to calc batch auc.")
.SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Area Under The Curve (AUC) Operator. Area Under The Curve (AUC) Operator.
......
...@@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel<T> {
std::string curve = ctx.Attr<std::string>("curve"); std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
// buckets contain numbers from 0 to num_thresholds
int num_pred_buckets = num_thresholds + 1; int num_pred_buckets = num_thresholds + 1;
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and // Only use output var for now, make sure it's persistable and
// not cleaned up for each batch. // not cleaned up for each batch.
...@@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel<T> {
auto *stat_pos = ctx.Output<Tensor>("StatPosOut"); auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut"); auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *stat_pos_data = stat_pos->mutable_data<int64_t>(ctx.GetPlace()); auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *stat_neg_data = stat_neg->mutable_data<int64_t>(ctx.GetPlace()); auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
auc);
auto *batch_auc = ctx.Output<Tensor>("BatchAUC"); std::vector<int64_t> stat_pos_data(num_pred_buckets, 0);
std::vector<int64_t> stat_pos_batch(num_pred_buckets, 0); std::vector<int64_t> stat_neg_data(num_pred_buckets, 0);
std::vector<int64_t> stat_neg_batch(num_pred_buckets, 0);
calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(), auto stat_pos_calc = stat_pos_data.data();
num_thresholds, batch_auc); auto stat_neg_calc = stat_neg_data.data();
statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps,
origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc);
calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc);
} }
private: private:
...@@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel<T> {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0; return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
} }
inline static void calcAuc(const framework::ExecutionContext &ctx, inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *label,
const framework::Tensor *predict, const framework::Tensor *predict,
int64_t *stat_pos, int64_t *stat_neg, const int num_pred_buckets,
int num_thresholds, const int num_thresholds, const int slide_steps,
framework::Tensor *auc_tensor) { int64_t *origin_stat_pos, int64_t *origin_stat_neg,
int64_t **stat_pos, int64_t **stat_neg) {
size_t batch_size = predict->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1]; size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>(); const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>(); const auto *label_data = label->data<int64_t>();
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
uint32_t binIdx = static_cast<uint32_t>( uint32_t binIdx = static_cast<uint32_t>(
inference_data[i * inference_width + 1] * num_thresholds); inference_data[i * inference_width + 1] * num_thresholds);
if (label_data[i]) { if (label_data[i]) {
stat_pos[binIdx] += 1.0; (*stat_pos)[binIdx] += 1.0;
} else {
(*stat_neg)[binIdx] += 1.0;
}
}
int bucket_length = num_pred_buckets * sizeof(int64_t);
// will stat auc unlimited.
if (slide_steps == 0) {
for (int slide = 0; slide < num_pred_buckets; ++slide) {
origin_stat_pos[slide] += (*stat_pos)[slide];
origin_stat_neg[slide] += (*stat_neg)[slide];
}
*stat_pos = origin_stat_pos;
*stat_neg = origin_stat_neg;
} else { } else {
stat_neg[binIdx] += 1.0; for (int slide = 1; slide < slide_steps; ++slide) {
int dst_idx = (slide - 1) * num_pred_buckets;
int src_inx = slide * num_pred_buckets;
std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx,
bucket_length);
std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx,
bucket_length);
}
std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets,
*stat_pos, bucket_length);
std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets,
*stat_neg, bucket_length);
std::memset(*stat_pos, 0, bucket_length);
std::memset(*stat_neg, 0, bucket_length);
for (int slide = 0; slide < num_pred_buckets; ++slide) {
int stat_pos_steps = 0;
int stat_neg_steps = 0;
for (int step = 0; step < slide_steps; ++step) {
stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets];
stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets];
}
(*stat_pos)[slide] += stat_pos_steps;
(*stat_neg)[slide] += stat_neg_steps;
}
} }
} }
inline static void calcAuc(const framework::ExecutionContext &ctx,
int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds,
framework::Tensor *auc_tensor) {
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
*auc = 0.0f; *auc = 0.0f;
double totPos = 0.0; double totPos = 0.0;
...@@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel<T> {
totPos += stat_pos[idx]; totPos += stat_pos[idx];
totNeg += stat_neg[idx]; totNeg += stat_neg[idx];
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev); *auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
--idx; --idx;
} }
......
...@@ -77,9 +77,11 @@ class ScaleOpVarTypeInference : public framework::VarTypeInference { ...@@ -77,9 +77,11 @@ class ScaleOpVarTypeInference : public framework::VarTypeInference {
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
auto *out_var = block->FindVarRecursive(out_var_name); auto *out_var = block->FindVarRecursive(out_var_name);
if (in_var_name != out_var_name) {
out_var->SetType(in_var.GetType()); out_var->SetType(in_var.GetType());
out_var->SetDataType(in_var.GetDataType()); out_var->SetDataType(in_var.GetDataType());
} }
}
}; };
class ScaleGradMaker : public framework::SingleGradOpDescMaker { class ScaleGradMaker : public framework::SingleGradOpDescMaker {
......
...@@ -32,7 +32,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,7 @@ class SumKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto in_vars = context.MultiInputVar("X"); auto in_vars = context.MultiInputVar("X");
int N = in_vars.size(); size_t in_num = in_vars.size();
auto out_var = context.OutputVar("Out"); auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0]; bool in_place = out_var == in_vars[0];
...@@ -53,7 +53,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -53,7 +53,7 @@ class SumKernel : public framework::OpKernel<T> {
auto &place = auto &place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
// If in_place, just skip the first tensor // If in_place, just skip the first tensor
for (int i = in_place ? 1 : 0; i < N; i++) { for (size_t i = in_place ? 1 : 0; i < in_num; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) { if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto &in_t = in_vars[i]->Get<framework::LoDTensor>(); auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (in_t.numel() == 0) { if (in_t.numel() == 0) {
...@@ -101,13 +101,13 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -101,13 +101,13 @@ class SumKernel : public framework::OpKernel<T> {
// Runtime InferShape // Runtime InferShape
size_t first_dim = 0; size_t first_dim = 0;
for (int i = 0; i < N; i++) { for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size(); first_dim += sel_row.rows().size();
} }
std::vector<int64_t> in_dim; std::vector<int64_t> in_dim;
for (int i = 0; i < N; i++) { for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() > 0) { if (sel_row.rows().size() > 0) {
in_dim = framework::vectorize(sel_row.value().dims()); in_dim = framework::vectorize(sel_row.value().dims());
...@@ -116,7 +116,8 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -116,7 +116,8 @@ class SumKernel : public framework::OpKernel<T> {
} }
if (in_dim.empty()) { if (in_dim.empty()) {
VLOG(3) << "WARNING: all the inputs are empty"; VLOG(3) << "WARNING: all the inputs are empty";
in_dim = framework::vectorize(get_selected_row(N - 1).value().dims()); in_dim =
framework::vectorize(get_selected_row(in_num - 1).value().dims());
} else { } else {
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
} }
...@@ -133,7 +134,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -133,7 +134,7 @@ class SumKernel : public framework::OpKernel<T> {
math::SelectedRowsAddTo<DeviceContext, T> functor; math::SelectedRowsAddTo<DeviceContext, T> functor;
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < N; i++) { for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) { if (sel_row.rows().size() == 0) {
continue; continue;
......
...@@ -77,13 +77,14 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -77,13 +77,14 @@ def download(url, module_name, md5sum, save_name=None):
retry_limit = 3 retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum): while not (os.path.exists(filename) and md5file(filename) == md5sum):
if os.path.exists(filename): if os.path.exists(filename):
print("file md5", md5file(filename), md5sum) sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum))
if retry < retry_limit: if retry < retry_limit:
retry += 1 retry += 1
else: else:
raise RuntimeError("Cannot download {0} within retry limit {1}". raise RuntimeError("Cannot download {0} within retry limit {1}".
format(url, retry_limit)) format(url, retry_limit))
print("Cache file %s not found, downloading %s" % (filename, url)) sys.stderr.write("Cache file %s not found, downloading %s" %
(filename, url))
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
...@@ -100,10 +101,11 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -100,10 +101,11 @@ def download(url, module_name, md5sum, save_name=None):
dl += len(data) dl += len(data)
f.write(data) f.write(data)
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done, sys.stderr.write("\r[%s%s]" % ('=' * done,
' ' * (50 - done))) ' ' * (50 - done)))
sys.stdout.flush() sys.stdout.flush()
sys.stderr.write("\n")
sys.stdout.flush()
return filename return filename
......
...@@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None): ...@@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None):
return acc_out return acc_out
def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): def auc(input,
label,
curve='ROC',
num_thresholds=2**12 - 1,
topk=1,
slide_steps=1):
""" """
**Area Under the Curve (AUC) Layer** **Area Under the Curve (AUC) Layer**
...@@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
num_thresholds(int): The number of thresholds to use when discretizing num_thresholds(int): The number of thresholds to use when discretizing
the roc curve. Default 200. the roc curve. Default 200.
topk(int): only topk number of prediction output will be used for auc. topk(int): only topk number of prediction output will be used for auc.
slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps.
Returns: Returns:
Variable: A scalar representing the current AUC. Variable: A scalar representing the current AUC.
...@@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
auc_out = helper.create_tmp_variable(dtype="float64") auc_out = helper.create_tmp_variable(dtype="float64")
batch_auc_out = helper.create_tmp_variable(dtype="float64") batch_auc_out = helper.create_tmp_variable(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc
batch_stat_pos = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
batch_stat_neg = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
# for global auc
stat_pos = helper.create_global_variable( stat_pos = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
stat_neg = helper.create_global_variable( stat_neg = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
for var in [stat_pos, stat_neg]: for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
value=0.0, force_cpu=True)) value=0.0, force_cpu=True))
# Batch AUC
helper.append_op(
type="auc",
inputs={
"Predict": [input],
"Label": [label],
"StatPos": [batch_stat_pos],
"StatNeg": [batch_stat_neg]
},
attrs={
"curve": curve,
"num_thresholds": num_thresholds,
"slide_steps": slide_steps
},
outputs={
"AUC": [batch_auc_out],
"StatPosOut": [batch_stat_pos],
"StatNegOut": [batch_stat_neg]
})
# Global AUC
helper.append_op( helper.append_op(
type="auc", type="auc",
inputs={ inputs={
...@@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
"StatPos": [stat_pos], "StatPos": [stat_pos],
"StatNeg": [stat_neg] "StatNeg": [stat_neg]
}, },
attrs={"curve": curve, attrs={
"num_thresholds": num_thresholds}, "curve": curve,
"num_thresholds": num_thresholds,
"slide_steps": 0
},
outputs={ outputs={
"AUC": [auc_out], "AUC": [auc_out],
"BatchAUC": [batch_auc_out],
"StatPosOut": [stat_pos], "StatPosOut": [stat_pos],
"StatNegOut": [stat_neg] "StatNegOut": [stat_neg]
}) })
return auc_out, batch_auc_out, [stat_pos, stat_neg] return auc_out, batch_auc_out, [
batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
]
...@@ -17,6 +17,9 @@ if(NOT WITH_DISTRIBUTE) ...@@ -17,6 +17,9 @@ if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op) list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec) LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
LIST(REMOVE_ITEM TEST_OPS test_dist_ctr)
LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
endif(NOT WITH_DISTRIBUTE) endif(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import dist_ctr_reader
from test_dist_base import TestDistRunnerBase, runtime_main
IS_SPARSE = True
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestDistCTR2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta()
""" network definition """
dnn_data = fluid.layers.data(
name="dnn_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
lr_data = fluid.layers.data(
name="lr_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
label = fluid.layers.data(
name="click",
shape=[-1, 1],
dtype="int64",
lod_level=0,
append_batch_size=False)
# build dnn model
dnn_layer_dims = [128, 64, 32, 1]
dnn_embedding = fluid.layers.embedding(
is_distributed=False,
input=dnn_data,
size=[dnn_input_dim, dnn_layer_dims[0]],
param_attr=fluid.ParamAttr(
name="deep_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=IS_SPARSE)
dnn_pool = fluid.layers.sequence_pool(
input=dnn_embedding, pool_type="sum")
dnn_out = dnn_pool
for i, dim in enumerate(dnn_layer_dims[1:]):
fc = fluid.layers.fc(
input=dnn_out,
size=dim,
act="relu",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)),
name='dnn-fc-%d' % i)
dnn_out = fc
# build lr model
lr_embbding = fluid.layers.embedding(
is_distributed=False,
input=lr_data,
size=[lr_input_dim, 1],
param_attr=fluid.ParamAttr(
name="wide_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=IS_SPARSE)
lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
acc = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
label=label)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
inference_program = paddle.fluid.default_main_program().clone()
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
sgd_optimizer.minimize(avg_cost)
dataset = dist_ctr_reader.Dataset()
train_reader = paddle.batch(dataset.train(), batch_size=batch_size)
test_reader = paddle.batch(dataset.test(), batch_size=batch_size)
return inference_program, avg_cost, train_reader, test_reader, None, predict
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
import tarfile
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
DATA_URL = "http://paddle-ctr-data.cdn.bcebos.com/avazu_ctr_data.tgz"
DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e"
"""
avazu_ctr_data/train.txt
avazu_ctr_data/infer.txt
avazu_ctr_data/test.txt
avazu_ctr_data/data.meta.txt
"""
def read_data(file_name):
path = paddle.dataset.common.download(DATA_URL, "avazu_ctr_data", DATA_MD5)
tar = tarfile.open(path, "r:gz")
tar_info = None
for member in tar.getmembers():
if member.name.endswith(file_name):
tar_info = member
f = tar.extractfile(tar_info)
ret_lines = [_.decode('utf-8') for _ in f.readlines()]
return ret_lines
class TaskMode:
TRAIN_MODE = 0
TEST_MODE = 1
INFER_MODE = 2
def __init__(self, mode):
self.mode = mode
def is_train(self):
return self.mode == self.TRAIN_MODE
def is_test(self):
return self.mode == self.TEST_MODE
def is_infer(self):
return self.mode == self.INFER_MODE
@staticmethod
def create_train():
return TaskMode(TaskMode.TRAIN_MODE)
@staticmethod
def create_test():
return TaskMode(TaskMode.TEST_MODE)
@staticmethod
def create_infer():
return TaskMode(TaskMode.INFER_MODE)
class ModelType:
CLASSIFICATION = 0
REGRESSION = 1
def __init__(self, mode):
self.mode = mode
def is_classification(self):
return self.mode == self.CLASSIFICATION
def is_regression(self):
return self.mode == self.REGRESSION
@staticmethod
def create_classification():
return ModelType(ModelType.CLASSIFICATION)
@staticmethod
def create_regression():
return ModelType(ModelType.REGRESSION)
def load_dnn_input_record(sent):
return list(map(int, sent.split()))
def load_lr_input_record(sent):
res = []
for _ in [x.split(':') for x in sent.split()]:
res.append(int(_[0]))
return res
feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}
class Dataset(object):
def train(self):
'''
Load trainset.
'''
file_name = "train.txt"
logger.info("load trainset from %s" % file_name)
mode = TaskMode.create_train()
return self._parse_creator(file_name, mode)
def test(self):
'''
Load testset.
'''
file_name = "test.txt"
logger.info("load testset from %s" % file_name)
mode = TaskMode.create_test()
return self._parse_creator(file_name, mode)
def infer(self):
'''
Load infer set.
'''
file_name = "infer.txt"
logger.info("load inferset from %s" % file_name)
mode = TaskMode.create_infer()
return self._parse_creator(file_name, mode)
def _parse_creator(self, file_name, mode):
'''
Parse dataset.
'''
def _parse():
data = read_data(file_name)
for line_id, line in enumerate(data):
fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1])
if not mode.is_infer():
click = int(fs[2])
yield [dnn_input, lr_input, click]
else:
yield [dnn_input, lr_input]
return _parse
def load_data_meta():
'''
load data meta info from path, return (dnn_input_dim, lr_input_dim)
'''
lines = read_data('data.meta.txt')
err_info = "wrong meta format"
assert len(lines) == 2, err_info
assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[
1], err_info
res = map(int, [_.split(':')[1] for _ in lines])
res = list(res)
logger.info('dnn input dim: %d' % res[0])
logger.info('lr input dim: %d' % res[1])
return res
...@@ -47,7 +47,7 @@ def cnn_model(data): ...@@ -47,7 +47,7 @@ def cnn_model(data):
pool_stride=2, pool_stride=2,
act="relu", act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.3))) value=0.01)))
conv_pool_2 = fluid.nets.simple_img_conv_pool( conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1, input=conv_pool_1,
filter_size=5, filter_size=5,
...@@ -56,7 +56,7 @@ def cnn_model(data): ...@@ -56,7 +56,7 @@ def cnn_model(data):
pool_stride=2, pool_stride=2,
act="relu", act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.2))) value=0.01)))
SIZE = 10 SIZE = 10
input_shape = conv_pool_2.shape input_shape = conv_pool_2.shape
...@@ -68,7 +68,7 @@ def cnn_model(data): ...@@ -68,7 +68,7 @@ def cnn_model(data):
size=SIZE, size=SIZE,
act="softmax", act="softmax",
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1))) initializer=fluid.initializer.Constant(value=0.01)))
return predict return predict
......
...@@ -247,7 +247,7 @@ class DistSeResneXt2x2(TestDistRunnerBase): ...@@ -247,7 +247,7 @@ class DistSeResneXt2x2(TestDistRunnerBase):
# Reader # Reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.flowers.train(), batch_size=batch_size) paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size) paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import time
import math
import random
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main
DTYPE = "int64"
DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/simnet.train.1000'
DATA_MD5 = '24e49366eb0611c552667989de2f57d5'
# For Net
base_lr = 0.2
emb_lr = base_lr * 3
dict_dim = 1500
emb_dim = 128
hid_dim = 128
margin = 0.1
sample_rate = 1
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
def get_acc(cos_q_nt, cos_q_pt, batch_size):
cond = fluid.layers.less_than(cos_q_nt, cos_q_pt)
cond = fluid.layers.cast(cond, dtype='float64')
cond_3 = fluid.layers.reduce_sum(cond)
acc = fluid.layers.elementwise_div(
cond_3,
fluid.layers.fill_constant(
shape=[1], value=batch_size * 1.0, dtype='float64'),
name="simnet_acc")
return acc
def get_loss(cos_q_pt, cos_q_nt):
loss_op1 = fluid.layers.elementwise_sub(
fluid.layers.fill_constant_batch_size_like(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32'),
cos_q_pt)
loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt)
loss_op3 = fluid.layers.elementwise_max(
fluid.layers.fill_constant_batch_size_like(
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_op2)
avg_cost = fluid.layers.mean(loss_op3)
return avg_cost
def get_optimizer():
# SGD optimizer
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
return optimizer
def train_network(batch_size, is_distributed=False, is_sparse=False):
# query
q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1)
## embedding
q_emb = fluid.layers.embedding(
input=q,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
## vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum)
## fc layer after conv
q_fc = fluid.layers.fc(
input=q_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__q_fc__",
learning_rate=base_lr))
# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# pt
pt = fluid.layers.data(
name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
## embedding
pt_emb = fluid.layers.embedding(
input=pt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
## vsum
pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
pt_ss = fluid.layers.softsign(pt_sum)
## fc layer
pt_fc = fluid.layers.fc(
input=pt_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__fc__",
learning_rate=base_lr),
bias_attr=fluid.ParamAttr(name="__fc_b__"))
# nt
nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
## embedding
nt_emb = fluid.layers.embedding(
input=nt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
## vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
nt_ss = fluid.layers.softsign(nt_sum)
## fc layer
nt_fc = fluid.layers.fc(
input=nt_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__fc__",
learning_rate=base_lr),
bias_attr=fluid.ParamAttr(name="__fc_b__"))
cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc)
cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc)
# loss
avg_cost = get_loss(cos_q_pt, cos_q_nt)
# acc
acc = get_acc(cos_q_nt, cos_q_pt, batch_size)
return [avg_cost, acc, cos_q_pt]
def combination(x, y):
res = [[[xi, yi] for yi in y] for xi in x]
return res[0]
def get_one_data(file_list):
for file in file_list:
contents = []
with open(file, "r") as fin:
for i in fin:
contents.append(i.strip())
for index, q in enumerate(contents):
try:
one_data = [[int(j) for j in i.split(" ")]
for i in q.split(";")[:-1]]
if one_data[1][0] + one_data[1][1] != len(one_data) - 3:
q = fin.readline()
continue
tmp = combination(one_data[3:3 + one_data[1][0]],
one_data[3 + one_data[1][0]:])
except Exception as e:
continue
for each in tmp:
yield [one_data[2], 0, each[0], each[1]]
def get_batch_reader(file_list, batch_size):
def batch_reader():
res = []
for i in get_one_data(file_list):
if random.random() <= sample_rate:
res.append(i)
if len(res) >= batch_size:
yield res
res = []
return batch_reader
def get_train_reader(batch_size):
# The training data set.
train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet",
"train")
train_reader = get_batch_reader([train_file], batch_size)
train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"]
return train_reader, train_feed
class TestDistSimnetBow2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
# Train program
avg_cost, acc, predict = \
train_network(batch_size, bool(int(os.environ["IS_DISTRIBUTED"])), bool(int(os.environ["IS_SPARSE"])))
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer()
opt.minimize(avg_cost)
# Reader
train_reader, _ = get_train_reader(batch_size)
return inference_program, avg_cost, train_reader, train_reader, acc, predict
if __name__ == "__main__":
paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
runtime_main(TestDistSimnetBow2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import time
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
import six
import tarfile
import string
import re
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main
DTYPE = "float32"
VOCAB_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/imdb.vocab'
VOCAB_MD5 = '23c86a0533c0151b6f12fa52b106dcc2'
DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/text_classification.tar.gz'
DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'
# Load dictionary.
def load_vocab(filename):
vocab = {}
if six.PY2:
with open(filename, 'r') as f:
for idx, line in enumerate(f):
vocab[line.strip()] = idx
else:
with open(filename, 'r', encoding="utf-8") as f:
for idx, line in enumerate(f):
vocab[line.strip()] = idx
return vocab
def get_worddict(dict_path):
word_dict = load_vocab(dict_path)
word_dict["<unk>"] = len(word_dict)
dict_dim = len(word_dict)
return word_dict, dict_dim
def conv_net(input,
dict_dim,
emb_dim=128,
window_size=3,
num_filters=128,
fc0_dim=96,
class_dim=2):
emb = fluid.layers.embedding(
input=input,
size=[dict_dim, emb_dim],
is_sparse=False,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.01)))
conv_3 = fluid.nets.sequence_conv_pool(
input=emb,
num_filters=num_filters,
filter_size=window_size,
act="tanh",
pool_type="max",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
fc_0 = fluid.layers.fc(
input=[conv_3],
size=fc0_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
prediction = fluid.layers.fc(
input=[fc_0],
size=class_dim,
act="softmax",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
return prediction
def inference_network(dict_dim):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
out = conv_net(data, dict_dim)
return out
def get_reader(word_dict, batch_size):
# The training data set.
train_reader = paddle.batch(train(word_dict), batch_size=batch_size)
# The testing data set.
test_reader = paddle.batch(test(word_dict), batch_size=batch_size)
return train_reader, test_reader
def get_optimizer(learning_rate):
optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
return optimizer
class TestDistTextClassification2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
vocab = os.path.join(paddle.dataset.common.DATA_HOME,
"text_classification", "imdb.vocab")
word_dict, dict_dim = get_worddict(vocab)
# Input data
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = conv_net(data, dict_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=predict, label=label)
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer(learning_rate=0.001)
opt.minimize(avg_cost)
# Reader
train_reader, test_reader = get_reader(word_dict, batch_size)
return inference_program, avg_cost, train_reader, test_reader, acc, predict
def tokenize(pattern):
"""
Read files that match the given pattern. Tokenize and yield each file.
"""
with tarfile.open(
paddle.dataset.common.download(DATA_URL, 'text_classification',
DATA_MD5)) as tarf:
# Note that we should use tarfile.next(), which does
# sequential access of member files, other than
# tarfile.extractfile, which does random access and might
# destroy hard disks.
tf = tarf.next()
while tf != None:
if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip(six.b(
"\n\r")).translate(
None, six.b(string.punctuation)).lower().split()
tf = tarf.next()
def reader_creator(pos_pattern, neg_pattern, word_idx):
UNK = word_idx['<unk>']
INS = []
def load(pattern, out, label):
for doc in tokenize(pattern):
out.append(([word_idx.get(w, UNK) for w in doc], label))
load(pos_pattern, INS, 0)
load(neg_pattern, INS, 1)
def reader():
for doc, label in INS:
yield doc, label
return reader
def train(word_idx):
"""
IMDB training set creator.
It returns a reader creator, each sample in the reader is an zero-based ID
sequence and label in [0, 1].
:param word_idx: word dictionary
:type word_idx: dict
:return: Training reader creator
:rtype: callable
"""
return reader_creator(
re.compile("train/pos/.*\.txt$"),
re.compile("train/neg/.*\.txt$"), word_idx)
def test(word_idx):
"""
IMDB test set creator.
It returns a reader creator, each sample in the reader is an zero-based ID
sequence and label in [0, 1].
:param word_idx: word dictionary
:type word_idx: dict
:return: Test reader creator
:rtype: callable
"""
return reader_creator(
re.compile("test/pos/.*\.txt$"),
re.compile("test/neg/.*\.txt$"), word_idx)
if __name__ == "__main__":
paddle.dataset.common.download(VOCAB_URL, 'text_classification', VOCAB_MD5)
paddle.dataset.common.download(DATA_URL, 'text_classification', DATA_MD5)
runtime_main(TestDistTextClassification2x2)
...@@ -1699,10 +1699,9 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1699,10 +1699,9 @@ class DistTransformer2x2(TestDistRunnerBase):
exe.run(startup_prog) exe.run(startup_prog)
exe.run(pserver_prog) exe.run(pserver_prog)
def run_trainer(self, use_cuda, args): def run_trainer(self, args):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() TrainTaskConfig.use_gpu = args.use_cuda
TrainTaskConfig.use_gpu = use_cuda sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode) args.is_dist, not args.sync_mode)
if args.is_dist: if args.is_dist:
...@@ -1718,6 +1717,11 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1718,6 +1717,11 @@ class DistTransformer2x2(TestDistRunnerBase):
TrainTaskConfig.batch_size = 20 TrainTaskConfig.batch_size = 20
trainer_prog = fluid.default_main_program() trainer_prog = fluid.default_main_program()
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
TrainTaskConfig.local = not args.is_dist TrainTaskConfig.local = not args.is_dist
......
...@@ -122,4 +122,7 @@ class TestDistWord2vec2x2(TestDistRunnerBase): ...@@ -122,4 +122,7 @@ class TestDistWord2vec2x2(TestDistRunnerBase):
if __name__ == "__main__": if __name__ == "__main__":
import os
os.environ['CPU_NUM'] = '1'
os.environ['USE_CUDA'] = "FALSE"
runtime_main(TestDistWord2vec2x2) runtime_main(TestDistWord2vec2x2)
...@@ -36,7 +36,11 @@ class TestAucOp(OpTest): ...@@ -36,7 +36,11 @@ class TestAucOp(OpTest):
"StatPos": stat_pos, "StatPos": stat_pos,
"StatNeg": stat_neg "StatNeg": stat_neg
} }
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": 1
}
python_auc = metrics.Auc(name="auc", python_auc = metrics.Auc(name="auc",
curve='ROC', curve='ROC',
...@@ -45,7 +49,6 @@ class TestAucOp(OpTest): ...@@ -45,7 +49,6 @@ class TestAucOp(OpTest):
self.outputs = { self.outputs = {
'AUC': np.array(python_auc.eval()), 'AUC': np.array(python_auc.eval()),
'BatchAUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos), 'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg) 'StatNegOut': np.array(python_auc._stat_neg)
} }
......
...@@ -18,23 +18,27 @@ import time ...@@ -18,23 +18,27 @@ import time
import unittest import unittest
import os import os
import sys import sys
import six
import signal import signal
import subprocess import subprocess
import six
import argparse import argparse
import paddle.fluid as fluid
RUN_STEP = 10
class TestDistRunnerBase(object): class TestDistRunnerBase(object):
def get_model(self, batch_size=2): def get_model(self, batch_size=2):
raise NotImplementedError( raise NotImplementedError(
"get_model should be implemented by child classes.") "get_model should be implemented by child classes.")
def get_transpiler(self, trainer_id, main_program, pserver_endpoints, @staticmethod
trainers, sync_mode): def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers,
sync_mode):
# NOTE: import fluid until runtime, or else forking processes will cause error. # NOTE: import fluid until runtime, or else forking processes will cause error.
import paddle config = fluid.DistributeTranspilerConfig()
import paddle.fluid as fluid t = fluid.DistributeTranspiler(config=config)
t = fluid.DistributeTranspiler()
t.transpile( t.transpile(
trainer_id=trainer_id, trainer_id=trainer_id,
program=main_program, program=main_program,
...@@ -44,9 +48,9 @@ class TestDistRunnerBase(object): ...@@ -44,9 +48,9 @@ class TestDistRunnerBase(object):
return t return t
def run_pserver(self, args): def run_pserver(self, args):
import paddle
import paddle.fluid as fluid
self.get_model(batch_size=2) self.get_model(batch_size=2)
if args.mem_opt: if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
t = self.get_transpiler(args.trainer_id, t = self.get_transpiler(args.trainer_id,
...@@ -61,12 +65,10 @@ class TestDistRunnerBase(object): ...@@ -61,12 +65,10 @@ class TestDistRunnerBase(object):
exe.run(startup_prog) exe.run(startup_prog)
exe.run(pserver_prog) exe.run(pserver_prog)
def run_trainer(self, use_cuda, args): def run_trainer(self, args):
import paddle
import paddle.fluid as fluid
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=2) self.get_model(batch_size=2)
if args.mem_opt: if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
if args.is_dist: if args.is_dist:
...@@ -74,16 +76,23 @@ class TestDistRunnerBase(object): ...@@ -74,16 +76,23 @@ class TestDistRunnerBase(object):
fluid.default_main_program(), fluid.default_main_program(),
args.endpoints, args.trainers, args.endpoints, args.trainers,
args.sync_mode) args.sync_mode)
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
else: else:
trainer_prog = fluid.default_main_program() trainer_prog = fluid.default_main_program()
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
startup_exe.run(fluid.default_startup_program()) startup_exe.run(fluid.default_startup_program())
strategy = fluid.ExecutionStrategy() strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1 strategy.num_threads = 1
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy() build_stra = fluid.BuildStrategy()
if args.use_reduce: if args.use_reduce:
...@@ -92,7 +101,7 @@ class TestDistRunnerBase(object): ...@@ -92,7 +101,7 @@ class TestDistRunnerBase(object):
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
use_cuda, args.use_cuda,
loss_name=avg_cost.name, loss_name=avg_cost.name,
exec_strategy=strategy, exec_strategy=strategy,
build_strategy=build_stra) build_strategy=build_stra)
...@@ -103,27 +112,26 @@ class TestDistRunnerBase(object): ...@@ -103,27 +112,26 @@ class TestDistRunnerBase(object):
] ]
feeder = fluid.DataFeeder(feed_var_list, place) feeder = fluid.DataFeeder(feed_var_list, place)
reader_generator = test_reader() reader_generator = train_reader()
data = next(reader_generator) def get_data():
first_loss, = exe.run(fetch_list=[avg_cost.name], origin_batch = next(reader_generator)
feed=feeder.feed(data)) if args.is_dist and args.use_reader_alloc:
print(first_loss) new_batch = []
for offset, item in enumerate(origin_batch):
for i in six.moves.xrange(5): if offset % 2 == args.trainer_id:
data = next(reader_generator) new_batch.append(item)
loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) return new_batch
else:
return origin_batch
data = next(reader_generator) for _ in six.moves.xrange(RUN_STEP):
last_loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) loss, = exe.run(fetch_list=[avg_cost.name],
print(last_loss) feed=feeder.feed(get_data()))
print(loss)
def runtime_main(test_class): def runtime_main(test_class):
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
parser = argparse.ArgumentParser(description='Run dist test.') parser = argparse.ArgumentParser(description='Run dist test.')
parser.add_argument( parser.add_argument(
'--role', type=str, required=True, choices=['pserver', 'trainer']) '--role', type=str, required=True, choices=['pserver', 'trainer'])
...@@ -135,7 +143,10 @@ def runtime_main(test_class): ...@@ -135,7 +143,10 @@ def runtime_main(test_class):
'--current_endpoint', type=str, required=False, default="") '--current_endpoint', type=str, required=False, default="")
parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true') parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--use_reduce', action='store_true')
parser.add_argument(
'--use_reader_alloc', action='store_true', required=False, default=True)
args = parser.parse_args() args = parser.parse_args()
...@@ -143,8 +154,7 @@ def runtime_main(test_class): ...@@ -143,8 +154,7 @@ def runtime_main(test_class):
if args.role == "pserver" and args.is_dist: if args.role == "pserver" and args.is_dist:
model.run_pserver(args) model.run_pserver(args)
else: else:
use_cuda = True if core.is_compiled_with_cuda() else False model.run_trainer(args)
model.run_trainer(use_cuda, args)
import paddle.compat as cpt import paddle.compat as cpt
...@@ -156,6 +166,17 @@ class TestDistBase(unittest.TestCase): ...@@ -156,6 +166,17 @@ class TestDistBase(unittest.TestCase):
def _setup_config(self): def _setup_config(self):
raise NotImplementedError("tests should have _setup_config implemented") raise NotImplementedError("tests should have _setup_config implemented")
def _after_setup_config(self):
if self._enforce_place == "CPU":
self.__use_cuda = False
elif self._enforce_place == "GPU":
self.__use_cuda = True
else:
if fluid.core.is_compiled_with_cuda():
self.__use_cuda = True
else:
self.__use_cuda = False
def setUp(self): def setUp(self):
self._trainers = 2 self._trainers = 2
self._pservers = 2 self._pservers = 2
...@@ -163,16 +184,19 @@ class TestDistBase(unittest.TestCase): ...@@ -163,16 +184,19 @@ class TestDistBase(unittest.TestCase):
self._find_free_port(), self._find_free_port()) self._find_free_port(), self._find_free_port())
self._python_interp = "python" self._python_interp = "python"
self._sync_mode = True self._sync_mode = True
self._enforce_place = None
self._mem_opt = False self._mem_opt = False
self._use_reduce = False self._use_reduce = False
self._use_reader_alloc = True
self._setup_config() self._setup_config()
self._after_setup_config()
def _find_free_port(self): def _find_free_port(self):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0)) s.bind(('', 0))
return s.getsockname()[1] return s.getsockname()[1]
def start_pserver(self, model_file, check_error_log): def start_pserver(self, model_file, check_error_log, required_envs):
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist" ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
ps0_cmd = ps_cmd % \ ps0_cmd = ps_cmd % \
...@@ -189,22 +213,22 @@ class TestDistBase(unittest.TestCase): ...@@ -189,22 +213,22 @@ class TestDistBase(unittest.TestCase):
ps0_cmd += " --mem_opt" ps0_cmd += " --mem_opt"
ps1_cmd += " --mem_opt" ps1_cmd += " --mem_opt"
ps0_pipe = subprocess.PIPE
ps1_pipe = subprocess.PIPE
if check_error_log:
print(ps0_cmd) print(ps0_cmd)
print(ps1_cmd) print(ps1_cmd)
ps0_pipe = open("/tmp/ps0_err.log", "wb") ps0_pipe = open("/tmp/ps0_err.log", "wb")
ps1_pipe = open("/tmp/ps1_err.log", "wb") ps1_pipe = open("/tmp/ps1_err.log", "wb")
ps0_proc = subprocess.Popen( ps0_proc = subprocess.Popen(
ps0_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps0_pipe) ps0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps0_pipe,
env=required_envs)
ps1_proc = subprocess.Popen( ps1_proc = subprocess.Popen(
ps1_cmd.strip().split(" "), stdout=subprocess.PIPE, stderr=ps1_pipe) ps1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps1_pipe,
env=required_envs)
if not check_error_log:
return ps0_proc, ps1_proc, None, None
else:
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
def _wait_ps_ready(self, pid): def _wait_ps_ready(self, pid):
...@@ -222,58 +246,58 @@ class TestDistBase(unittest.TestCase): ...@@ -222,58 +246,58 @@ class TestDistBase(unittest.TestCase):
(e, retry_times)) (e, retry_times))
retry_times -= 1 retry_times -= 1
def check_with_place(self, model_file, delta=1e-3, check_error_log=False): def _run_local(self, model, envs, check_error_log):
# TODO(typhoonzero): should auto adapt GPU count on the machine.
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_cudnn_deterministic": "1",
"CPU_NUM": "1"
}
if check_error_log: cmd = "%s %s --role trainer" % (self._python_interp, model)
required_envs["GLOG_v"] = "7"
required_envs["GLOG_logtostderr"] = "1"
# Run local to get a base line if self.__use_cuda:
cmd += " --use_cuda"
env_local = {"CUDA_VISIBLE_DEVICES": "0"} env_local = {"CUDA_VISIBLE_DEVICES": "0"}
env_local.update(required_envs)
local_cmd = "%s %s --role trainer" % (self._python_interp, model_file)
if not check_error_log:
local_proc = subprocess.Popen(
local_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_local)
else: else:
env_local = {'CPU_NUM': '1'}
envs.update(env_local)
if check_error_log:
err_log = open("/tmp/trainer.err.log", "wb") err_log = open("/tmp/trainer.err.log", "wb")
local_proc = subprocess.Popen( local_proc = subprocess.Popen(
local_cmd.split(" "), cmd.split(" "),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=err_log, stderr=err_log,
env=env_local) env=envs)
else:
local_proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=envs)
local_out, local_err = local_proc.communicate()
local_ret = cpt.to_text(local_out)
if check_error_log:
err_log.close()
sys.stderr.write('local_stdout: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % local_err)
local_proc.wait() local_losses = local_ret.split("\n")
out, err = local_proc.communicate() return local_losses
local_ret = cpt.to_text(out)
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
def _run_cluster(self, model, envs, check_error_log):
# Run dist train to compare with local results # Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model_file, ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
check_error_log) check_error_log, envs)
self._wait_ps_ready(ps0.pid) self._wait_ps_ready(ps0.pid)
self._wait_ps_ready(ps1.pid) self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist" tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
tr0_cmd = tr_cmd % \ tr0_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints, (self._python_interp, model, self._ps_endpoints,
0, ps0_ep, self._trainers) 0, ps0_ep, self._trainers)
tr1_cmd = tr_cmd % \ tr1_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints, (self._python_interp, model, self._ps_endpoints,
1, ps1_ep, self._trainers) 1, ps1_ep, self._trainers)
if self._sync_mode: if self._sync_mode:
...@@ -285,18 +309,23 @@ class TestDistBase(unittest.TestCase): ...@@ -285,18 +309,23 @@ class TestDistBase(unittest.TestCase):
if self._use_reduce: if self._use_reduce:
tr0_cmd += " --use_reduce" tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce" tr1_cmd += " --use_reduce"
if self._use_reader_alloc:
tr0_cmd += " --use_reader_alloc"
tr1_cmd += " --use_reader_alloc"
if self.__use_cuda:
tr0_cmd += " --use_cuda"
tr1_cmd += " --use_cuda"
env0 = {"CUDA_VISIBLE_DEVICES": "0"} env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"} env1 = {"CUDA_VISIBLE_DEVICES": "1"}
env0.update(required_envs) else:
env1.update(required_envs) env0 = {'CPU_NUM': '1'}
FNULL = open(os.devnull, 'w') env1 = {'CPU_NUM': '1'}
tr0_pipe = subprocess.PIPE env0.update(envs)
tr1_pipe = subprocess.PIPE env1.update(envs)
if check_error_log:
print("tr0_cmd:", tr0_cmd) print("tr0_cmd:{}, env0: {}".format(tr0_cmd, env0))
print("tr1_cmd:", tr1_cmd) print("tr1_cmd:{}, env1: {}".format(tr1_cmd, env1))
tr0_pipe = open("/tmp/tr0_err.log", "wb") tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb") tr1_pipe = open("/tmp/tr1_err.log", "wb")
...@@ -311,22 +340,12 @@ class TestDistBase(unittest.TestCase): ...@@ -311,22 +340,12 @@ class TestDistBase(unittest.TestCase):
stderr=tr1_pipe, stderr=tr1_pipe,
env=env1) env=env1)
tr0_proc.wait() tr0_out, tr0_err = tr0_proc.communicate()
tr1_proc.wait() tr0_loss_text = cpt.to_text(tr0_out)
out, err = tr0_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err) tr1_loss_text = cpt.to_text(tr1_out)
loss_data0 = cpt.to_text(out)
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
dist_last_loss = eval(lines[1].replace(" ", ","))[0]
local_lines = local_ret.split("\n")
local_first_loss = eval(local_lines[0])[0]
local_last_loss = eval(local_lines[1])[0]
# close trainer file # close trainer file
if check_error_log:
tr0_pipe.close() tr0_pipe.close()
tr1_pipe.close() tr1_pipe.close()
...@@ -337,9 +356,49 @@ class TestDistBase(unittest.TestCase): ...@@ -337,9 +356,49 @@ class TestDistBase(unittest.TestCase):
os.kill(ps1.pid, signal.SIGKILL) os.kill(ps1.pid, signal.SIGKILL)
ps0.terminate() ps0.terminate()
ps1.terminate() ps1.terminate()
ps0.wait()
ps1.wait()
FNULL.close()
self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta) # print log
self.assertAlmostEqual(local_last_loss, dist_last_loss, delta=delta) sys.stderr.write('trainer 0 stdout:\n %s\n' % tr0_loss_text)
sys.stderr.write('trainer 0 stderr:\n %s\n' % tr0_err)
sys.stderr.write('trainer 1 stdout: %s\n' % tr1_loss_text)
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
tr0_losses = tr0_loss_text.split("\n")
tr1_losses = tr1_loss_text.split("\n")
return tr0_losses, tr1_losses
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
# TODO(typhoonzero): should auto adapt GPU count on the machine.
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_cudnn_deterministic": "1",
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "7"
required_envs["GLOG_logtostderr"] = "1"
local_losses\
= self._run_local(model_file, required_envs,
check_error_log)
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs,
check_error_log)
for step_id in range(RUN_STEP):
local_loss = eval(local_losses[step_id])[0]
tr0_loss = eval(tr0_losses[step_id])[0]
tr1_loss = eval(tr1_losses[step_id])[0]
dist_loss = (tr0_loss + tr1_loss) / 2
print(str(local_loss) + ":" + str(dist_loss))
self.assertAlmostEqual(local_loss, dist_loss, delta=delta)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
from test_dist_base import TestDistBase
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7)
if __name__ == "__main__":
unittest.main()
...@@ -23,7 +23,7 @@ class TestDistMnist2x2(TestDistBase): ...@@ -23,7 +23,7 @@ class TestDistMnist2x2(TestDistBase):
self._use_reduce = False self._use_reduce = False
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnist2x2WithMemopt(TestDistBase): class TestDistMnist2x2WithMemopt(TestDistBase):
...@@ -32,7 +32,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase): ...@@ -32,7 +32,7 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
self._mem_opt = True self._mem_opt = True
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnistAsync(TestDistBase): class TestDistMnistAsync(TestDistBase):
......
...@@ -20,9 +20,10 @@ from test_dist_base import TestDistBase ...@@ -20,9 +20,10 @@ from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase): class TestDistSeResneXt2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
self._use_reader_alloc = False
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7) self.check_with_place("dist_se_resnext.py", delta=100)
# TODO(typhoonzero): fix this test # TODO(typhoonzero): fix this test
...@@ -38,6 +39,7 @@ class TestDistSeResneXt2x2(TestDistBase): ...@@ -38,6 +39,7 @@ class TestDistSeResneXt2x2(TestDistBase):
class TestDistSeResneXt2x2Async(TestDistBase): class TestDistSeResneXt2x2Async(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
from test_dist_base import TestDistBase
class TestDistSimnetBowDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
need_envs=need_envs)
class TestDistSimnetBow2x2DenseAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
self.check_with_place(
"dist_simnet_bow.py",
delta=100,
check_error_log=False,
need_envs=need_envs)
class TestDistSimnetBowSparse2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
need_envs=need_envs)
class TestDistSimnetBow2x2SparseAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
self.check_with_place(
"dist_simnet_bow.py",
delta=100,
check_error_log=False,
need_envs=need_envs)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
from test_dist_base import TestDistBase
class TestDistTextClassification2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_text_classification(self):
self.check_with_place("dist_text_classification.py", delta=1e-6)
class TestDistTextClassification2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_se_resnext(self):
self.check_with_place("dist_text_classification.py", delta=100)
if __name__ == "__main__":
unittest.main()
...@@ -39,7 +39,7 @@ class TestDistW2V2x2Async(TestDistBase): ...@@ -39,7 +39,7 @@ class TestDistW2V2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1) self.check_with_place("dist_word2vec.py", delta=100)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -537,7 +537,7 @@ class DistributeTranspiler(object): ...@@ -537,7 +537,7 @@ class DistributeTranspiler(object):
}) })
for varname, splited_var in six.iteritems(self.param_var_mapping): for varname, splited_var in six.iteritems(self.param_var_mapping):
#add concat ops to merge splited parameters received from parameter servers. # add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
# NOTE: if enable memory optimization, origin vars maybe removed. # NOTE: if enable memory optimization, origin vars maybe removed.
...@@ -737,19 +737,14 @@ in a single call.") ...@@ -737,19 +737,14 @@ in a single call.")
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
optimize_blocks.append(table_opt_block) optimize_blocks.append(table_opt_block)
prefetch_var_name_to_block_id = self._create_prefetch_block( lookup_table_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
checkpoint_block_id = self._create_checkpoint_save_block( checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx) pserver_program, table_opt_block.idx)
pserver_program._distributed_lookup_table = self.table_name pserver_program._distributed_lookup_table = self.table_name
prefetch_var_name_to_block_id.extend(
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will lookup_table_var_name_to_block_id)
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
assert len(prefetch_var_name_to_block_id) > 0
else:
assert len(prefetch_var_name_to_block_id) == 0
attrs = { attrs = {
"optimize_blocks": optimize_blocks, "optimize_blocks": optimize_blocks,
...@@ -758,11 +753,14 @@ in a single call.") ...@@ -758,11 +753,14 @@ in a single call.")
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id, "grad_to_block_id": grad_to_block_id,
} }
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \ if self.has_distributed_lookup_table:
= prefetch_var_name_to_block_id
attrs['checkpint_block_id'] = checkpoint_block_id attrs['checkpint_block_id'] = checkpoint_block_id
if len(prefetch_var_name_to_block_id) > 0:
attrs[
'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="listen_and_serv", type="listen_and_serv",
...@@ -1492,7 +1490,6 @@ to transpile() call.") ...@@ -1492,7 +1490,6 @@ to transpile() call.")
per_trainer_name = "%s.trainer_%d" % \ per_trainer_name = "%s.trainer_%d" % \
(merged_var_name, i) (merged_var_name, i)
vars2merge.append(pserver_block.vars[per_trainer_name]) vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op( optimize_block.append_op(
type="sum", type="sum",
inputs={"X": vars2merge}, inputs={"X": vars2merge},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册