提交 3f7a7eac 编写于 作者: T tangwei12

Batch AUC (#13567)

* add distributed auc

* add attr "is distributed" and config it

* add distributed auc

* add batch auc and code format

* code format

* auc optimize

* metric_op optimize

* code clean

* bug fix and code clean

* bug fix and code clean

* code optimize

* code optimize

* api spec update

* Comments optimized

* add mutex

* Revert: add mutex

* remove distribute metric

* remove distribute metric

* spec modifyed

* add annotation, test=develop

* keep API compatibility
test=develop
上级 644bad1d
...@@ -286,7 +286,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw ...@@ -286,7 +286,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......
...@@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -36,11 +36,16 @@ class AucOp : public framework::OperatorWithKernel {
"Out and Label should have same height."); "Out and Label should have same height.");
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1; int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps");
PADDLE_ENFORCE_GE(num_pred_buckets, 1, "num_thresholds must larger than 1");
PADDLE_ENFORCE_GE(slide_steps, 0, "slide_steps must be natural number");
ctx->SetOutputDim("AUC", {1}); ctx->SetOutputDim("AUC", {1});
ctx->SetOutputDim("BatchAUC", {1});
ctx->SetOutputDim("StatPosOut", {num_pred_buckets}); slide_steps = slide_steps == 0 ? 1 : slide_steps;
ctx->SetOutputDim("StatNegOut", {num_pred_buckets}); ctx->SetOutputDim("StatPosOut", {slide_steps, num_pred_buckets});
ctx->SetOutputDim("StatNegOut", {slide_steps, num_pred_buckets});
} }
protected: protected:
...@@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -62,6 +67,7 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data. " "A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]"); "shape: [batch_size, 1]");
// TODO(typhoonzero): support weight input // TODO(typhoonzero): support weight input
AddInput("StatPos", "Statistic value when label = 1"); AddInput("StatPos", "Statistic value when label = 1");
AddInput("StatNeg", "Statistic value when label = 0"); AddInput("StatNeg", "Statistic value when label = 0");
...@@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,18 +75,19 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("AUC", AddOutput("AUC",
"A scalar representing the " "A scalar representing the "
"current area-under-the-curve."); "current area-under-the-curve.");
AddOutput("BatchAUC", "The AUC for current batch");
AddOutput("StatPosOut", "Statistic value when label = 1"); AddOutput("StatPosOut", "Statistic value when label = 1");
AddOutput("StatNegOut", "Statistic value when label = 0"); AddOutput("StatNegOut", "Statistic value when label = 0");
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.") AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC"); .SetDefault("ROC");
AddAttr<int>("num_thresholds", AddAttr<int>(
"The number of thresholds to use when discretizing the" "num_thresholds",
" roc curve.") "The number of thresholds to use when discretizing the roc curve.")
.SetDefault((2 << 12) - 1); .SetDefault((2 << 12) - 1);
AddAttr<int>("slide_steps", "Use slide steps to calc batch auc.")
.SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Area Under The Curve (AUC) Operator. Area Under The Curve (AUC) Operator.
......
...@@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,9 @@ class AucKernel : public framework::OpKernel<T> {
std::string curve = ctx.Attr<std::string>("curve"); std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
// buckets contain numbers from 0 to num_thresholds
int num_pred_buckets = num_thresholds + 1; int num_pred_buckets = num_thresholds + 1;
int slide_steps = ctx.Attr<int>("slide_steps");
// Only use output var for now, make sure it's persistable and // Only use output var for now, make sure it's persistable and
// not cleaned up for each batch. // not cleaned up for each batch.
...@@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -40,16 +42,19 @@ class AucKernel : public framework::OpKernel<T> {
auto *stat_pos = ctx.Output<Tensor>("StatPosOut"); auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
auto *stat_neg = ctx.Output<Tensor>("StatNegOut"); auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
auto *stat_pos_data = stat_pos->mutable_data<int64_t>(ctx.GetPlace()); auto *origin_stat_pos = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
auto *stat_neg_data = stat_neg->mutable_data<int64_t>(ctx.GetPlace()); auto *origin_stat_neg = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
auc);
auto *batch_auc = ctx.Output<Tensor>("BatchAUC"); std::vector<int64_t> stat_pos_data(num_pred_buckets, 0);
std::vector<int64_t> stat_pos_batch(num_pred_buckets, 0); std::vector<int64_t> stat_neg_data(num_pred_buckets, 0);
std::vector<int64_t> stat_neg_batch(num_pred_buckets, 0);
calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(), auto stat_pos_calc = stat_pos_data.data();
num_thresholds, batch_auc); auto stat_neg_calc = stat_neg_data.data();
statAuc(label, predict, num_pred_buckets, num_thresholds, slide_steps,
origin_stat_pos, origin_stat_neg, &stat_pos_calc, &stat_neg_calc);
calcAuc(ctx, stat_pos_calc, stat_neg_calc, num_thresholds, auc);
} }
private: private:
...@@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -58,29 +63,76 @@ class AucKernel : public framework::OpKernel<T> {
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0; return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
} }
inline static void calcAuc(const framework::ExecutionContext &ctx, inline static void statAuc(const framework::Tensor *label,
const framework::Tensor *label,
const framework::Tensor *predict, const framework::Tensor *predict,
int64_t *stat_pos, int64_t *stat_neg, const int num_pred_buckets,
int num_thresholds, const int num_thresholds, const int slide_steps,
framework::Tensor *auc_tensor) { int64_t *origin_stat_pos, int64_t *origin_stat_neg,
int64_t **stat_pos, int64_t **stat_neg) {
size_t batch_size = predict->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1]; size_t inference_width = predict->dims()[1];
const T *inference_data = predict->data<T>(); const T *inference_data = predict->data<T>();
const auto *label_data = label->data<int64_t>(); const auto *label_data = label->data<int64_t>();
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
uint32_t binIdx = static_cast<uint32_t>( uint32_t binIdx = static_cast<uint32_t>(
inference_data[i * inference_width + 1] * num_thresholds); inference_data[i * inference_width + 1] * num_thresholds);
if (label_data[i]) { if (label_data[i]) {
stat_pos[binIdx] += 1.0; (*stat_pos)[binIdx] += 1.0;
} else { } else {
stat_neg[binIdx] += 1.0; (*stat_neg)[binIdx] += 1.0;
} }
} }
int bucket_length = num_pred_buckets * sizeof(int64_t);
// will stat auc unlimited.
if (slide_steps == 0) {
for (int slide = 0; slide < num_pred_buckets; ++slide) {
origin_stat_pos[slide] += (*stat_pos)[slide];
origin_stat_neg[slide] += (*stat_neg)[slide];
}
*stat_pos = origin_stat_pos;
*stat_neg = origin_stat_neg;
} else {
for (int slide = 1; slide < slide_steps; ++slide) {
int dst_idx = (slide - 1) * num_pred_buckets;
int src_inx = slide * num_pred_buckets;
std::memcpy(origin_stat_pos + dst_idx, origin_stat_pos + src_inx,
bucket_length);
std::memcpy(origin_stat_neg + dst_idx, origin_stat_neg + src_inx,
bucket_length);
}
std::memcpy(origin_stat_pos + (slide_steps - 1) * num_pred_buckets,
*stat_pos, bucket_length);
std::memcpy(origin_stat_neg + (slide_steps - 1) * num_pred_buckets,
*stat_neg, bucket_length);
std::memset(*stat_pos, 0, bucket_length);
std::memset(*stat_neg, 0, bucket_length);
for (int slide = 0; slide < num_pred_buckets; ++slide) {
int stat_pos_steps = 0;
int stat_neg_steps = 0;
for (int step = 0; step < slide_steps; ++step) {
stat_pos_steps += origin_stat_pos[slide + step * num_pred_buckets];
stat_neg_steps += origin_stat_neg[slide + step * num_pred_buckets];
}
(*stat_pos)[slide] += stat_pos_steps;
(*stat_neg)[slide] += stat_neg_steps;
}
}
}
inline static void calcAuc(const framework::ExecutionContext &ctx,
int64_t *stat_pos, int64_t *stat_neg,
int num_thresholds,
framework::Tensor *auc_tensor) {
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
*auc = 0.0f; *auc = 0.0f;
double totPos = 0.0; double totPos = 0.0;
...@@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -96,7 +148,6 @@ class AucKernel : public framework::OpKernel<T> {
totPos += stat_pos[idx]; totPos += stat_pos[idx];
totNeg += stat_neg[idx]; totNeg += stat_neg[idx];
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev); *auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
--idx; --idx;
} }
......
...@@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None): ...@@ -78,7 +78,12 @@ def accuracy(input, label, k=1, correct=None, total=None):
return acc_out return acc_out
def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): def auc(input,
label,
curve='ROC',
num_thresholds=2**12 - 1,
topk=1,
slide_steps=1):
""" """
**Area Under the Curve (AUC) Layer** **Area Under the Curve (AUC) Layer**
...@@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -105,6 +110,8 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
num_thresholds(int): The number of thresholds to use when discretizing num_thresholds(int): The number of thresholds to use when discretizing
the roc curve. Default 200. the roc curve. Default 200.
topk(int): only topk number of prediction output will be used for auc. topk(int): only topk number of prediction output will be used for auc.
slide_steps: when calc batch auc, we can not only use step currently but the previous steps can be used. slide_steps=1 means use the current step, slide_steps=3 means use current step and the previous second steps, slide_steps=0 use all of the steps.
Returns: Returns:
Variable: A scalar representing the current AUC. Variable: A scalar representing the current AUC.
...@@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -120,16 +127,48 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
auc_out = helper.create_tmp_variable(dtype="float64") auc_out = helper.create_tmp_variable(dtype="float64")
batch_auc_out = helper.create_tmp_variable(dtype="float64") batch_auc_out = helper.create_tmp_variable(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc
batch_stat_pos = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
batch_stat_neg = helper.create_global_variable(
persistable=True,
dtype='int64',
shape=[slide_steps, num_thresholds + 1])
# for global auc
stat_pos = helper.create_global_variable( stat_pos = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
stat_neg = helper.create_global_variable( stat_neg = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds + 1]) persistable=True, dtype='int64', shape=[1, num_thresholds + 1])
for var in [stat_pos, stat_neg]: for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
value=0.0, force_cpu=True)) value=0.0, force_cpu=True))
# Batch AUC
helper.append_op(
type="auc",
inputs={
"Predict": [input],
"Label": [label],
"StatPos": [batch_stat_pos],
"StatNeg": [batch_stat_neg]
},
attrs={
"curve": curve,
"num_thresholds": num_thresholds,
"slide_steps": slide_steps
},
outputs={
"AUC": [batch_auc_out],
"StatPosOut": [batch_stat_pos],
"StatNegOut": [batch_stat_neg]
})
# Global AUC
helper.append_op( helper.append_op(
type="auc", type="auc",
inputs={ inputs={
...@@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1): ...@@ -138,12 +177,16 @@ def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
"StatPos": [stat_pos], "StatPos": [stat_pos],
"StatNeg": [stat_neg] "StatNeg": [stat_neg]
}, },
attrs={"curve": curve, attrs={
"num_thresholds": num_thresholds}, "curve": curve,
"num_thresholds": num_thresholds,
"slide_steps": 0
},
outputs={ outputs={
"AUC": [auc_out], "AUC": [auc_out],
"BatchAUC": [batch_auc_out],
"StatPosOut": [stat_pos], "StatPosOut": [stat_pos],
"StatNegOut": [stat_neg] "StatNegOut": [stat_neg]
}) })
return auc_out, batch_auc_out, [stat_pos, stat_neg] return auc_out, batch_auc_out, [
batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
]
...@@ -36,7 +36,11 @@ class TestAucOp(OpTest): ...@@ -36,7 +36,11 @@ class TestAucOp(OpTest):
"StatPos": stat_pos, "StatPos": stat_pos,
"StatNeg": stat_neg "StatNeg": stat_neg
} }
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": 1
}
python_auc = metrics.Auc(name="auc", python_auc = metrics.Auc(name="auc",
curve='ROC', curve='ROC',
...@@ -45,7 +49,6 @@ class TestAucOp(OpTest): ...@@ -45,7 +49,6 @@ class TestAucOp(OpTest):
self.outputs = { self.outputs = {
'AUC': np.array(python_auc.eval()), 'AUC': np.array(python_auc.eval()),
'BatchAUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos), 'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg) 'StatNegOut': np.array(python_auc._stat_neg)
} }
......
...@@ -39,8 +39,8 @@ import six ...@@ -39,8 +39,8 @@ import six
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework from .. import core, framework
from ..framework import Program, default_main_program, \ from ..framework import Program, default_main_program, \
default_startup_program, Block, \ default_startup_program, Block, \
Parameter, grad_var_name Parameter, grad_var_name
from .details import * from .details import *
from functools import reduce from functools import reduce
...@@ -178,7 +178,7 @@ class DistributeTranspiler(object): ...@@ -178,7 +178,7 @@ class DistributeTranspiler(object):
pserver_program) pserver_program)
elif role == "TRAINER": elif role == "TRAINER":
trainer_program = t.get_trainer_program() trainer_program = t.get_trainer_program()
# for nccl2 mode # for nccl2 mode
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2" config.mode = "nccl2"
...@@ -534,7 +534,7 @@ class DistributeTranspiler(object): ...@@ -534,7 +534,7 @@ class DistributeTranspiler(object):
}) })
for varname, splited_var in six.iteritems(self.param_var_mapping): for varname, splited_var in six.iteritems(self.param_var_mapping):
#add concat ops to merge splited parameters received from parameter servers. # add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
# NOTE: if enable memory optimization, origin vars maybe removed. # NOTE: if enable memory optimization, origin vars maybe removed.
...@@ -734,19 +734,14 @@ in a single call.") ...@@ -734,19 +734,14 @@ in a single call.")
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
optimize_blocks.append(table_opt_block) optimize_blocks.append(table_opt_block)
prefetch_var_name_to_block_id = self._create_prefetch_block( lookup_table_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
checkpoint_block_id = self._create_checkpoint_save_block( checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx) pserver_program, table_opt_block.idx)
pserver_program._distributed_lookup_table = self.table_name pserver_program._distributed_lookup_table = self.table_name
prefetch_var_name_to_block_id.extend(
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will lookup_table_var_name_to_block_id)
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
assert len(prefetch_var_name_to_block_id) > 0
else:
assert len(prefetch_var_name_to_block_id) == 0
attrs = { attrs = {
"optimize_blocks": optimize_blocks, "optimize_blocks": optimize_blocks,
...@@ -755,11 +750,14 @@ in a single call.") ...@@ -755,11 +750,14 @@ in a single call.")
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id, "grad_to_block_id": grad_to_block_id,
} }
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \ if self.has_distributed_lookup_table:
= prefetch_var_name_to_block_id
attrs['checkpint_block_id'] = checkpoint_block_id attrs['checkpint_block_id'] = checkpoint_block_id
if len(prefetch_var_name_to_block_id) > 0:
attrs[
'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="listen_and_serv", type="listen_and_serv",
...@@ -1013,7 +1011,7 @@ to transpile() call.") ...@@ -1013,7 +1011,7 @@ to transpile() call.")
for g, p in zip(grad_blocks, param_blocks): for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":") g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":") p_name, p_bid, _ = p.split(":")
self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \ self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
self.param_var_mapping[p_name][int(p_bid)] self.param_var_mapping[p_name][int(p_bid)]
# create mapping of endpoint -> split var to create pserver side program # create mapping of endpoint -> split var to create pserver side program
...@@ -1320,7 +1318,7 @@ to transpile() call.") ...@@ -1320,7 +1318,7 @@ to transpile() call.")
if len(splited) == 1: if len(splited) == 1:
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.trainer_%d" % \ new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id) (orig_var.name, self.trainer_id)
program.global_block()._rename_var(varname, new_var_name) program.global_block()._rename_var(varname, new_var_name)
var_mapping[varname] = \ var_mapping[varname] = \
[program.global_block().var(new_var_name)] [program.global_block().var(new_var_name)]
...@@ -1343,10 +1341,10 @@ to transpile() call.") ...@@ -1343,10 +1341,10 @@ to transpile() call.")
new_var_name = "" new_var_name = ""
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.block%d.trainer_%d" % \ new_var_name = "%s.block%d.trainer_%d" % \
(varname, i, self.trainer_id) (varname, i, self.trainer_id)
else: else:
new_var_name = "%s.block%d" % \ new_var_name = "%s.block%d" % \
(varname, i) (varname, i)
var = program.global_block().create_var( var = program.global_block().create_var(
name=new_var_name, name=new_var_name,
persistable=False, persistable=False,
...@@ -1484,7 +1482,7 @@ to transpile() call.") ...@@ -1484,7 +1482,7 @@ to transpile() call.")
vars2merge = [] vars2merge = []
for i in range(self.trainer_num): for i in range(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \ per_trainer_name = "%s.trainer_%d" % \
(merged_var_name, i) (merged_var_name, i)
vars2merge.append(pserver_block.vars[per_trainer_name]) vars2merge.append(pserver_block.vars[per_trainer_name])
optimize_block.append_op( optimize_block.append_op(
...@@ -1645,7 +1643,7 @@ to transpile() call.") ...@@ -1645,7 +1643,7 @@ to transpile() call.")
# one op's output is another op's input, we say # one op's output is another op's input, we say
# the two operator is connected. # the two operator is connected.
if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \ if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \
set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()): set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
return True return True
return False return False
...@@ -1662,7 +1660,7 @@ to transpile() call.") ...@@ -1662,7 +1660,7 @@ to transpile() call.")
def _is_optimizer_op(self, op): def _is_optimizer_op(self, op):
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
return True return True
return False return False
...@@ -1737,7 +1735,7 @@ to transpile() call.") ...@@ -1737,7 +1735,7 @@ to transpile() call.")
# NOTE: we need to skip all optimize ops, since it is connected # NOTE: we need to skip all optimize ops, since it is connected
# with forward/backward ops and lr ops, we only need the lr ops. # with forward/backward ops and lr ops, we only need the lr ops.
if op1 != op2 and self._is_op_connected(op1, op2) and \ if op1 != op2 and self._is_op_connected(op1, op2) and \
not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2): not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
ufind.union(op1, op2) ufind.union(op1, op2)
# find all ops which is related with lr var # find all ops which is related with lr var
for op1 in block.ops: for op1 in block.ops:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册