diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 130558b09165dbe31f2062c2aebbab0a84d56b24..aec9123ed9074b044c4ef3bf1354440983c83428 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -269,7 +269,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) -paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1)) +paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) diff --git a/paddle/fluid/operators/auc_op.cc b/paddle/fluid/operators/auc_op.cc index dfaa7456f917c1308984b361afed752f96ea6f59..0784920064a879963cd9725cd9acf4cec7b874ce 100644 --- a/paddle/fluid/operators/auc_op.cc +++ b/paddle/fluid/operators/auc_op.cc @@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel { "Out and Label should have same height."); int num_pred_buckets = ctx->Attrs().Get("num_thresholds") + 1; + int slide_steps = ctx->Attrs().Get("slide_steps"); + + PADDLE_ENFORCE_GE(num_pred_buckets, 1, "num_thresholds must larger than 1"); + PADDLE_ENFORCE_GE(slide_steps, 0, "slide_steps must be natural number"); ctx->SetOutputDim("AUC", {1}); - ctx->SetOutputDim("BatchAUC", {1}); - ctx->SetOutputDim("StatPosOut", {num_pred_buckets}); - ctx->SetOutputDim("StatNegOut", {num_pred_buckets}); + + slide_steps = slide_steps == 0 ? 1 : slide_steps; + ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets}); + ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets}); } protected: @@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Label", "A 2D int tensor indicating the label of the training data. " "shape: [batch_size, 1]"); + // TODO(typhoonzero): support weight input AddInput("StatPos", "Statistic value when label = 1"); AddInput("StatNeg", "Statistic value when label = 0"); @@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("AUC", "A scalar representing the " "current area-under-the-curve."); - AddOutput("BatchAUC", "The AUC for current batch"); + AddOutput("StatPosOut", "Statistic value when label = 1"); AddOutput("StatNegOut", "Statistic value when label = 0"); AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.") .SetDefault("ROC"); - AddAttr("num_thresholds", - "The number of thresholds to use when discretizing the" - " roc curve.") + AddAttr( + "num_thresholds", + "The number of thresholds to use when discretizing the roc curve.") .SetDefault((2 << 12) - 1); - + AddAttr("slide_steps", "Use slide steps to calc batch auc.") + .SetDefault(1); AddComment(R"DOC( Area Under The Curve (AUC) Operator. diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h index fb0517d70635e090f8c5b59ff9d8420fc34c747b..fb370842d1942c3b3eebecb1fe5e8ffb845cb34b 100644 --- a/paddle/fluid/operators/auc_op.h +++ b/paddle/fluid/operators/auc_op.h @@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel { std::string curve = ctx.Attr("curve"); int num_thresholds = ctx.Attr("num_thresholds"); + // buckets contain numbers from 0 to num_thresholds int num_pred_buckets = num_thresholds + 1; + int slide_steps = ctx.Attr("slide_steps"); // Only use output var for now, make sure it's persistable and // not cleaned up for each batch. @@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel { auto *stat_pos = ctx.Output("StatPosOut"); auto *stat_neg = ctx.Output("StatNegOut"); - auto *stat_pos_data = stat_pos->mutable_data(ctx.GetPlace()); - auto *stat_neg_data = stat_neg->mutable_data(ctx.GetPlace()); - calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds, - auc); + auto *origin_stat_pos = stat_pos->mutable_data(ctx.GetPlace()); + auto *origin_stat_neg = stat_neg->mutable_data(ctx.GetPlace()); - auto *batch_auc = ctx.Output("BatchAUC"); - std::vector stat_pos_batch(num_pred_buckets, 0); - std::vector stat_neg_batch(num_pred_buckets, 0); - calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(), - num_thresholds, batch_auc); + std::vector stat_pos_data(num_pred_buckets, 0); + std::vector stat_neg_data(num_pred_buckets, 0); + + auto stat_pos_calc = stat_pos_data.data(); + auto stat_neg_calc = stat_neg_data.data(); + + statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps, + origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc); + + calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc); } private: @@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel { return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0; } - inline static void calcAuc(const framework::ExecutionContext &ctx, - const framework::Tensor *label, + inline static void statAuc(const framework::Tensor *label, const framework::Tensor *predict, - int64_t *stat_pos, int64_t *stat_neg, - int num_thresholds, - framework::Tensor *auc_tensor) { + const int num_pred_buckets, + const int num_thresholds, const int slide_steps, + int64_t *origin_stat_pos, int64_t *origin_stat_neg, + int64_t **stat_pos, int64_t **stat_neg) { size_t batch_size = predict->dims()[0]; size_t inference_width = predict->dims()[1]; const T *inference_data = predict->data(); const auto *label_data = label->data(); - auto *auc = auc_tensor->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < batch_size; i++) { uint32_t binIdx = static_cast( inference_data[i * inference_width + 1] * num_thresholds); if (label_data[i]) { - stat_pos[binIdx] += 1.0; + (*stat_pos)[binIdx] += 1.0; } else { - stat_neg[binIdx] += 1.0; + (*stat_neg)[binIdx] += 1.0; } } + int bucket_length = num_pred_buckets * sizeof(int64_t); + + // will stat auc unlimited. + if (slide_steps == 0) { + for (int slide = 0; slide < num_pred_buckets; ++slide) { + origin_stat_pos[slide] += (*stat_pos)[slide]; + origin_stat_neg[slide] += (*stat_neg)[slide]; + } + + *stat_pos = origin_stat_pos; + *stat_neg = origin_stat_neg; + + } else { + for (int slide = 1; slide < slide_steps; ++slide) { + int dst_idx = (slide - 1) * num_pred_buckets; + int src_inx = slide * num_pred_buckets; + std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx, + bucket_length); + std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx, + bucket_length); + } + + std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets, + *stat_pos, bucket_length); + std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets, + *stat_neg, bucket_length); + + std::memset(*stat_pos, 0, bucket_length); + std::memset(*stat_neg, 0, bucket_length); + + for (int slide = 0; slide < num_pred_buckets; ++slide) { + int stat_pos_steps = 0; + int stat_neg_steps = 0; + for (int step = 0; step < slide_steps; ++step) { + stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets]; + stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets]; + } + (*stat_pos)[slide] += stat_pos_steps; + (*stat_neg)[slide] += stat_neg_steps; + } + } + } + + inline static void calcAuc(const framework::ExecutionContext &ctx, + int64_t *stat_pos, int64_t *stat_neg, + int num_thresholds, + framework::Tensor *auc_tensor) { + auto *auc = auc_tensor->mutable_data(ctx.GetPlace()); + *auc = 0.0f; double totPos = 0.0; @@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel { totPos += stat_pos[idx]; totNeg += stat_neg[idx]; *auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev); - --idx; } diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index b1598bfec210474ae1e17f9f88e8b57aa80b8452..a3064b565d096f7feda18379c66ffc8bf2f4a55c 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None): return acc_out -def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): +def auc(input, + label, + curve='ROC', + num_thresholds=2**12 - 1, + topk=1, + slide_steps=1): """ **Area Under the Curve (AUC) Layer** @@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): num_thresholds(int): The number of thresholds to use when discretizing the roc curve. Default 200. topk(int): only topk number of prediction output will be used for auc. + slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps. + Returns: Variable: A scalar representing the current AUC. @@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): auc_out = helper.create_tmp_variable(dtype="float64") batch_auc_out = helper.create_tmp_variable(dtype="float64") # make tp, tn, fp, fn persistable, so that can accumulate all batches. + + # for batch auc + batch_stat_pos = helper.create_global_variable( + persistable=True, + dtype='int64', + shape=[slide_steps, num_thresholds + 1]) + batch_stat_neg = helper.create_global_variable( + persistable=True, + dtype='int64', + shape=[slide_steps, num_thresholds + 1]) + + # for global auc stat_pos = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) stat_neg = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) - for var in [stat_pos, stat_neg]: + for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]: helper.set_variable_initializer( var, Constant( value=0.0, force_cpu=True)) + # Batch AUC + helper.append_op( + type="auc", + inputs={ + "Predict": [input], + "Label": [label], + "StatPos": [batch_stat_pos], + "StatNeg": [batch_stat_neg] + }, + attrs={ + "curve": curve, + "num_thresholds": num_thresholds, + "slide_steps": slide_steps + }, + outputs={ + "AUC": [batch_auc_out], + "StatPosOut": [batch_stat_pos], + "StatNegOut": [batch_stat_neg] + }) + # Global AUC helper.append_op( type="auc", inputs={ @@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): "StatPos": [stat_pos], "StatNeg": [stat_neg] }, - attrs={"curve": curve, - "num_thresholds": num_thresholds}, + attrs={ + "curve": curve, + "num_thresholds": num_thresholds, + "slide_steps": 0 + }, outputs={ "AUC": [auc_out], - "BatchAUC": [batch_auc_out], "StatPosOut": [stat_pos], "StatNegOut": [stat_neg] }) - return auc_out, batch_auc_out, [stat_pos, stat_neg] + return auc_out, batch_auc_out, [ + batch_stat_pos, batch_stat_neg, stat_pos, stat_neg + ] diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 1de4a9d016a177944253d12094722d3a05614be2..810e8a1a8547a92de923877695178e780981edeb 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -36,7 +36,11 @@ class TestAucOp(OpTest): "StatPos": stat_pos, "StatNeg": stat_neg } - self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} + self.attrs = { + 'curve': 'ROC', + 'num_thresholds': num_thresholds, + "slide_steps": 1 + } python_auc = metrics.Auc(name="auc", curve='ROC', @@ -45,7 +49,6 @@ class TestAucOp(OpTest): self.outputs = { 'AUC': np.array(python_auc.eval()), - 'BatchAUC': np.array(python_auc.eval()), 'StatPosOut': np.array(python_auc._stat_pos), 'StatNegOut': np.array(python_auc._stat_neg) } diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d9cc709f7428abcf415fe9a27b0d48a157ed0b41..3ddc1f3addbd722df498008432cbda9be986c1a4 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -39,8 +39,8 @@ import six from .ps_dispatcher import RoundRobin, HashName, PSDispatcher from .. import core, framework from ..framework import Program, default_main_program, \ - default_startup_program, Block, \ - Parameter, grad_var_name + default_startup_program, Block, \ + Parameter, grad_var_name from .details import * from functools import reduce @@ -178,7 +178,7 @@ class DistributeTranspiler(object): pserver_program) elif role == "TRAINER": trainer_program = t.get_trainer_program() - + # for nccl2 mode config = fluid.DistributeTranspilerConfig() config.mode = "nccl2" @@ -534,7 +534,7 @@ class DistributeTranspiler(object): }) for varname, splited_var in six.iteritems(self.param_var_mapping): - #add concat ops to merge splited parameters received from parameter servers. + # add concat ops to merge splited parameters received from parameter servers. if len(splited_var) <= 1: continue # NOTE: if enable memory optimization, origin vars maybe removed. @@ -734,19 +734,14 @@ in a single call.") table_opt_block = self._create_table_optimize_block( pserver_index, pserver_program, pre_block_idx, grad_to_block_id) optimize_blocks.append(table_opt_block) - prefetch_var_name_to_block_id = self._create_prefetch_block( + lookup_table_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) checkpoint_block_id = self._create_checkpoint_save_block( pserver_program, table_opt_block.idx) pserver_program._distributed_lookup_table = self.table_name - - # NOTE: if has_distributed_lookup_table is False, then prefetch_block will - # not be executed, so it's safe to use optimize_block to hold the place - if self.has_distributed_lookup_table: - assert len(prefetch_var_name_to_block_id) > 0 - else: - assert len(prefetch_var_name_to_block_id) == 0 + prefetch_var_name_to_block_id.extend( + lookup_table_var_name_to_block_id) attrs = { "optimize_blocks": optimize_blocks, @@ -755,11 +750,14 @@ in a single call.") "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, } - if len(prefetch_var_name_to_block_id) > 0: - attrs['prefetch_var_name_to_block_id'] \ - = prefetch_var_name_to_block_id + + if self.has_distributed_lookup_table: attrs['checkpint_block_id'] = checkpoint_block_id + if len(prefetch_var_name_to_block_id) > 0: + attrs[ + 'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id + # step5 append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", @@ -1013,7 +1011,7 @@ to transpile() call.") for g, p in zip(grad_blocks, param_blocks): g_name, g_bid, _ = g.split(":") p_name, p_bid, _ = p.split(":") - self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \ + self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \ self.param_var_mapping[p_name][int(p_bid)] # create mapping of endpoint -> split var to create pserver side program @@ -1320,7 +1318,7 @@ to transpile() call.") if len(splited) == 1: if self.sync_mode and add_trainer_suffix: new_var_name = "%s.trainer_%d" % \ - (orig_var.name, self.trainer_id) + (orig_var.name, self.trainer_id) program.global_block()._rename_var(varname, new_var_name) var_mapping[varname] = \ [program.global_block().var(new_var_name)] @@ -1343,10 +1341,10 @@ to transpile() call.") new_var_name = "" if self.sync_mode and add_trainer_suffix: new_var_name = "%s.block%d.trainer_%d" % \ - (varname, i, self.trainer_id) + (varname, i, self.trainer_id) else: new_var_name = "%s.block%d" % \ - (varname, i) + (varname, i) var = program.global_block().create_var( name=new_var_name, persistable=False, @@ -1484,7 +1482,7 @@ to transpile() call.") vars2merge = [] for i in range(self.trainer_num): per_trainer_name = "%s.trainer_%d" % \ - (merged_var_name, i) + (merged_var_name, i) vars2merge.append(pserver_block.vars[per_trainer_name]) optimize_block.append_op( @@ -1645,7 +1643,7 @@ to transpile() call.") # one op's output is another op's input, we say # the two operator is connected. if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \ - set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()): + set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()): return True return False @@ -1662,7 +1660,7 @@ to transpile() call.") def _is_optimizer_op(self, op): if "Param" in op.input_names and \ - "LearningRate" in op.input_names: + "LearningRate" in op.input_names: return True return False @@ -1737,7 +1735,7 @@ to transpile() call.") # NOTE: we need to skip all optimize ops, since it is connected # with forward/backward ops and lr ops, we only need the lr ops. if op1 != op2 and self._is_op_connected(op1, op2) and \ - not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2): + not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2): ufind.union(op1, op2) # find all ops which is related with lr var for op1 in block.ops: