提交 da61a5b6 编写于 作者: Q Qiao Longfei

Merge branch 'optimizer-prefetch' of https://github.com/seiriosPlus/Paddle into cpu-for-1.1-merge

......@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) {
std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes;
// q_set used to record records in the queue.
std::unordered_set<ir::Node *> q_set;
size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes,
&q_nodes](const std::vector<ir::Node *> &nodes) {
std::copy_if(
nodes.begin(), nodes.end(), std::back_inserter(q_nodes),
[&visited_nodes](Node *node) { return !visited_nodes.count(node); });
auto traverse_nodes = [&visited_nodes, &q_nodes,
&q_set](const std::vector<ir::Node *> &nodes) {
for (auto n : nodes) {
if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
}
}
};
while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) {
auto cur_node = q_nodes.front();
q_nodes.pop_front();
q_set.erase(cur_node);
visited_nodes.insert(cur_node);
g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs);
......@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) {
for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
break;
}
}
......
......@@ -28,12 +28,12 @@ enum class OpRole {
kBackward = 0x0001,
kOptimize = 0x0002,
// RPC role is for send/recv releated op
kRPC = 0x0003,
kRPC = 0x0004,
// Dist role is for split_byref/split_selected_rows/concat
// used for distributed training.
kDist = 0x0004,
kDist = 0x0008,
// Tag all learning rate scheduler operators.
kLRSched = 0x0005,
kLRSched = 0x0016,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
......@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_);
#endif
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
......@@ -20,13 +20,16 @@ namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddInput(
"X",
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
.AsDuplicable();
AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ")
.AsDuplicable();
AddInput("X",
"(LoDTensors) multi input tensor with shape{Rows, N}, N is the "
"size of embedding table")
.AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.")
.AsDuplicable();
AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.
......@@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out.");
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"MergeIdsOp must has multi input Ids.");
PADDLE_ENFORCE(ctx->HasInputs("Rows"),
"MergeIdsOp must has multi input Rows.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has multi input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"MergeIdsOp must has multi output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids");
auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[0][1], 1);
}
auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) {
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <tuple>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
......@@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel");
}
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"only support to merge Ids of LoDTensor");
const auto ids = ctx.MultiInput<framework::LoDTensor>("Ids");
const auto row_ids = ctx.MultiInput<framework::LoDTensor>("Rows");
const auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
const auto &ids_dims = ids_tensor.dims();
const int64_t *ids = ids_tensor.data<int64_t>();
PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(),
"the number of Rows and X should be the same");
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same");
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
int row_ids_size = 0;
int row_size = 0;
int embedding_size = 0;
auto *out = ctx.Output<framework::LoDTensor>("Out");
for (int i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
const auto *row_id = row_ids[i];
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) {
embedding_size = input->dims()[1];
embedding_size = x_tensor->dims()[1];
}
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1],
"embedding size of all input should be the same");
batch_size += input->dims()[0];
}
row_size += x_tensor->dims()[0];
row_ids_size += row_id->dims()[0];
}
PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0],
"the batch size of ids and merged embedding value should be the same");
row_size, row_ids_size,
"the merged X dim[0] and merged Rows dim[0] should be the same");
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map;
for (int i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val));
}
}
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
"the rows and tensor map size should be the same");
for (int i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i];
auto *out = outs[i];
const size_t shard_num = x_tensors.size();
out->set_lod(out_ids->lod());
if (shard_num == 1) {
VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(*x_tensors[0], place, out);
} else {
std::vector<int> in_indexs(shard_num, 0);
int nums = static_cast<int>(out_ids->dims()[0]);
auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place);
// copy data from ins[shard_num] to out.
for (int i = 0; i < ids_dims[0]; ++i) {
int64_t id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
int index = in_indexs[shard_id];
memcpy(out_data + embedding_size * i,
x_tensors[shard_id]->data<T>() + index * embedding_size,
framework::make_ddim({nums, embedding_size}), place);
for (int j = 0; j < nums; ++j) {
int id = out_ids->data<int64_t>()[j];
auto row_tuple = selected_rows_idx_map[id];
int64_t row_idx = std::get<1>(row_tuple);
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * j,
x_tensor->data<T>() + row_idx * embedding_size,
sizeof(T) * embedding_size);
in_indexs[shard_id] += 1;
}
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
}
}
}
......
......@@ -20,17 +20,24 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
.AsDuplicable();
AddOutput("Out", "(LoDTensors) The outputs of the input Ids.")
.AsDuplicable();
AddComment(R"DOC(
Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number
Example:
Input:
X = [1,2,3,4,5,6]
X = [[1,2,3,4,5,6],[2,3]]
Out(3 output):
if compress is True:
out0 = [3, 3, 6]
out1 = [1, 4]
out2 = [2, 2, 5]
else:
out0 = [3, 6]
out1 = [1, 4]
out2 = [2, 5]
......@@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids");
auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("Ids").front()->type()),
ctx.GetPlace());
}
};
class SplitIdsOpInferVarType : public framework::VarTypeInference {
......@@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
}
};
class SplitIdsOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad = new framework::OpDesc();
grad->SetType("concat");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("Ids"));
grad->SetAttr("axis", 0);
return std::unique_ptr<framework::OpDesc>(grad);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType);
ops::SplitIdsOpGradMaker, ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <iterator>
#include <set>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_THROW("SplitIds do not support GPU kernel");
}
const auto *ids_var = ctx.InputVar("Ids");
const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0");
auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims();
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>();
int batch_size = 0;
const auto ids_tensors = ctx.MultiInput<framework::LoDTensor>("Ids");
for (size_t i = 0; i < ids_tensors.size(); ++i) {
batch_size += ids_tensors[i]->dims()[0];
}
VLOG(4) << "Get Total BatchSize is: " << batch_size;
std::vector<T> all_ids(batch_size);
int offset = 0;
for (size_t i = 0; i < ids_tensors.size(); ++i) {
const auto *ids = ids_tensors[i];
std::memcpy(all_ids.data() + offset, ids->data<T>(),
ids->numel() * sizeof(T));
offset += ids->numel();
}
std::set<T> st(all_ids.begin(), all_ids.end());
all_ids.assign(st.begin(), st.end());
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size());
// split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i];
for (int i = 0; i < all_ids.size(); ++i) {
T id = all_ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
}
......@@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()),
"");
const T *ids = ids_selected_rows->value().data<T>();
const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
......@@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width,
ids + id_to_index[out->rows()[i]] * row_width,
ids_data + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T));
}
}
......
......@@ -316,7 +316,7 @@ class DetectionMAP(Evaluator):
gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
gt_difficult (Variable|None): Whether this ground truth is a difficult
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
it means all the ground truth labels are not difficult bbox.
......
......@@ -13,8 +13,6 @@
# limitations under the License.
"""
Fluid Metrics
The metrics are accomplished via Python natively.
"""
from __future__ import print_function
......@@ -24,6 +22,12 @@ import copy
import warnings
import six
from .layer_helper import LayerHelper
from .initializer import Constant
from . import unique_name
from .framework import Program, Variable, program_guard
from . import layers
__all__ = [
'MetricBase',
'CompositeMetric',
......@@ -478,67 +482,6 @@ class EditDistance(MetricBase):
return avg_distance, avg_instance_error
class DetectionMAP(MetricBase):
"""
Calculate the detection mean average precision (mAP).
mAP is the metric to measure the accuracy of object detectors
like Faster R-CNN, SSD, etc.
It is the average of the maximum precisions at different recall values.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
batch_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
metric = fluid.metrics.DetectionMAP()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
batch_size = data[0]
metric.update(value=batch_map, weight=batch_size)
numpy_map = metric.eval()
"""
def __init__(self, name=None):
super(DetectionMAP, self).__init__(name)
# the current map value
self.value = .0
self.weight = .0
def update(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
self.value += value
self.weight += weight
def eval(self):
if self.weight == 0:
raise ValueError(
"There is no data in DetectionMAP Metrics. "
"Please check layers.detection_map output has added to DetectionMAP."
)
return self.value / self.weight
class Auc(MetricBase):
"""
Auc metric adapts to the binary classification.
......@@ -616,3 +559,179 @@ class Auc(MetricBase):
idx -= 1
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
class DetectionMAP(object):
"""
Calculate the detection mean average precision (mAP).
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
Args:
input (Variable): The detection results, which is a LoDTensor with shape
[M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
gt_difficult (Variable|None): Whether this ground truth is a difficult
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
it means all the ground truth labels are not difficult bbox.
class_num (int): The class number.
background_label (int): The index of background label, the background
label will be ignored. If set to -1, then all categories will be
considered, 0 by defalut.
overlap_threshold (float): The threshold for deciding true/false
positive, 0.5 by defalut.
evaluate_difficult (bool): Whether to consider difficult ground truth
for evaluation, True by defalut. This argument does not work when
gt_difficult is None.
ap_version (string): The average precision calculation ways, it must be
'integral' or '11point'. Please check
https://sanchom.wordpress.com/tag/average-precision/ for details.
- 11point: the 11-point interpolated average precision.
- integral: the natural integral of the precision-recall curve.
Examples:
.. code-block:: python
exe = fluid.Executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_box, gt_difficult)
cur_map, accum_map = map_evaluator.get_map_var()
fetch = [cost, cur_map, accum_map]
for epoch in PASS_NUM:
map_evaluator.reset(exe)
for data in batches:
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
In the above example:
'cur_map_v' is the mAP of current mini-batch.
'accum_map_v' is the accumulative mAP of one pass.
"""
def __init__(self,
input,
gt_label,
gt_box,
gt_difficult=None,
class_num=None,
background_label=0,
overlap_threshold=0.5,
evaluate_difficult=True,
ap_version='integral'):
self.helper = LayerHelper('map_eval')
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
if gt_difficult:
gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
else:
label = layers.concat([gt_label, gt_box], axis=1)
# calculate mean average precision (mAP) of current mini-batch
map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
states = []
states.append(
self._create_state(
dtype='int32', shape=None, suffix='accum_pos_count'))
states.append(
self._create_state(
dtype='float32', shape=None, suffix='accum_true_pos'))
states.append(
self._create_state(
dtype='float32', shape=None, suffix='accum_false_pos'))
var = self._create_state(dtype='int32', shape=[1], suffix='has_state')
self.helper.set_variable_initializer(
var, initializer=Constant(value=int(0)))
self.has_state = var
# calculate accumulative mAP
accum_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
has_state=self.has_state,
input_states=states,
out_states=states,
ap_version=ap_version)
layers.fill_constant(
shape=self.has_state.shape,
value=1,
dtype=self.has_state.dtype,
out=self.has_state)
self.cur_map = map
self.accum_map = accum_map
def _create_state(self, suffix, dtype, shape):
"""
Create state variable.
Args:
suffix(str): the state suffix.
dtype(str|core.VarDesc.VarType): the state data type
shape(tuple|list): the shape of state
Returns: State variable
"""
state = self.helper.create_variable(
name="_".join([unique_name.generate(self.helper.name), suffix]),
persistable=True,
dtype=dtype,
shape=shape)
return state
def get_map_var(self):
"""
Returns: mAP variable of current mini-batch and
accumulative mAP variable cross mini-batches.
"""
return self.cur_map, self.accum_map
def reset(self, executor, reset_program=None):
"""
Reset metric states at the begin of each pass/user specified batch.
Args:
executor(Executor): a executor for executing
the reset_program.
reset_program(Program|None): a single Program for reset process.
If None, will create a Program.
"""
def _clone_var_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=var.persistable)
if reset_program is None:
reset_program = Program()
with program_guard(main_program=reset_program):
var = _clone_var_(reset_program.current_block(), self.has_state)
layers.fill_constant(
shape=var.shape, value=0, dtype=var.dtype, out=var)
executor.run(reset_program)
......@@ -23,8 +23,7 @@ class TestDistCTR2x2(TestDistBase):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
......
......@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False
self._use_reduce = False
# FIXME(typhoonzero): fix async mode test later
def no_test_dist_train(self):
def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200)
......
......@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False
self._use_reader_alloc = False
#FIXME(typhoonzero): fix async mode later
def no_test_dist_train(self):
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100)
......
......@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False
self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
......
......@@ -22,15 +22,28 @@ from op_test import OpTest
class TestMergeIdsOp(OpTest):
def setUp(self):
self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
[0.5, 0.6]]).astype('float32')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
self.outputs = {'Out': out}
ids1 = np.array([[0], [2], [5], [6]]).astype('int64')
ids2 = np.array([[0], [2], [2], [3]]).astype('int64')
rows1 = np.array([[0], [2]]).astype('int64')
rows2 = np.array([[3], [5]]).astype('int64')
rows3 = np.array([[6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3]]).astype('float32')
x1 = np.array([[0.3, 0.4], [0.4, 0.5]]).astype('float32')
x2 = np.array([[0.5, 0.6]]).astype('float32')
out1 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.4, 0.5], [0.5, 0.6]]).astype('float32')
out2 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
self.inputs = {
'Ids': [('ids1', ids1), ('ids2', ids2)],
"Rows": [('rows1', rows1), ('rows2', rows2), ('rows3', rows3)],
"X": [('x0', x0), ('x1', x1), ('x2', x2)]
}
self.outputs = {'Out': [('out1', out1), ('out2', out2)]}
def test_check_output(self):
self.check_output()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
class TestMetricsDetectionMap(unittest.TestCase):
def test_detection_map(self):
program = fluid.Program()
with program_guard(program):
detect_res = fluid.layers.data(
name='detect_res',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
label = fluid.layers.data(
name='label',
shape=[10, 1],
append_batch_size=False,
dtype='float32')
box = fluid.layers.data(
name='bbox',
shape=[10, 4],
append_batch_size=False,
dtype='float32')
map_eval = fluid.metrics.DetectionMAP(
detect_res, label, box, class_num=21)
cur_map, accm_map = map_eval.get_map_var()
self.assertIsNotNone(cur_map)
self.assertIsNotNone(accm_map)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator
class TestSplitIdsOp(OpTest):
def setUp(self):
self.op_type = "split_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
ids1 = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
ids2 = np.array([[6], [2], [3], [3], [5], [2], [6]]).astype('int64')
ids3 = np.array([[2], [2], [2], [3], [5], [5], [6]]).astype('int64')
out0 = np.array([[0], [3], [6]]).astype('int64')
out1 = np.array([[]]).astype('int64')
out2 = np.array([[2], [2], [5], [5]]).astype('int64')
self.inputs = {'Ids': ids}
out2 = np.array([[2], [5]]).astype('int64')
self.inputs = {'Ids': [('ids1', ids1), ('ids2', ids2), ('ids3', ids3)]}
self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]}
def test_check_output(self):
self.check_output()
class TestSpliteIds(unittest.TestCase):
class TestSplitSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
return places
......
......@@ -1034,15 +1034,11 @@ to transpile() call.")
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_in_ids_vars = []
self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = []
self.all_out_emb_vars = []
lookup_table_op_index = -1
continue_search_lookup_table_op = True
while continue_search_lookup_table_op:
......@@ -1052,42 +1048,50 @@ to transpile() call.")
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
lookup_table_op_index = list(all_ops).index(op)
lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(
all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
ids_var = program.global_block().vars[ids_name[0]]
prefetch_input_vars = self._create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
self.all_prefetch_input_vars.append(prefetch_input_vars)
self.all_in_ids_vars.append(ids_var)
out_var = program.global_block().vars[out_name[0]]
prefetch_output_vars = self._create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
self.all_prefetch_output_vars.append(prefetch_output_vars)
self.all_out_emb_vars.append(out_var)
# delete lookup_table_op
delete_ops(program.global_block(), [op])
# break for loop
break
for index in range(len(self.pserver_endpoints)):
in_var = program.global_block().create_var(
name=str("prefetch_compress_in_tmp_" + str(index)),
type=self.all_in_ids_vars[0].type,
shape=self.all_in_ids_vars[0].shape,
dtype=self.all_in_ids_vars[0].dtype)
self.all_prefetch_input_vars.append(in_var)
out_var = program.global_block().create_var(
name=str("prefetch_compress_out_tmp_" + str(index)),
type=self.all_out_emb_vars[0].type,
shape=self.all_out_emb_vars[0].shape,
dtype=self.all_out_emb_vars[0].dtype)
self.all_prefetch_output_vars.append(out_var)
# insert split_ids_op
program.global_block()._insert_op(
index=lookup_table_op_index,
type="split_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
]
},
outputs={"Out": prefetch_input_vars})
inputs={'Ids': self.all_in_ids_vars},
outputs={"Out": self.all_prefetch_input_vars})
# insert prefetch_op
program.global_block()._insert_op(
index=lookup_table_op_index + 1,
type="prefetch",
inputs={'X': prefetch_input_vars},
outputs={"Out": prefetch_output_vars},
inputs={'X': self.all_prefetch_input_vars},
outputs={"Out": self.all_prefetch_output_vars},
attrs={
"epmap": pserver_endpoints,
# FIXME(qiao) temporarily disable this config because prefetch
......@@ -1100,23 +1104,11 @@ to transpile() call.")
index=lookup_table_op_index + 2,
type="merge_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': prefetch_output_vars
'Ids': self.all_in_ids_vars,
'Rows': self.all_prefetch_input_vars,
'X': self.all_prefetch_output_vars
},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])
# break for loop
break
outputs={"Out": self.all_out_emb_vars})
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_op to send gradient to pservers
......@@ -1160,15 +1152,14 @@ to transpile() call.")
# STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name]
prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program._create_block(optimize_block.idx)
trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
trainer_ids = self.all_prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
trainer_out = self.all_prefetch_output_vars[index][pserver_index]
trainer_out = self.all_prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
......@@ -1364,16 +1355,6 @@ to transpile() call.")
program.global_block()._sync_with_cpp()
return var_mapping
def _create_splited_vars(self, source_var, block, tag):
return [
block.create_var(
name=str(source_var.name + tag + str(index)),
type=source_var.type,
shape=source_var.shape,
dtype=source_var.dtype)
for index in range(len(self.pserver_endpoints))
]
def _clone_var(self, block, var, persistable=True):
return block.create_var(
name=var.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册