Unverified · Commit e85fcaa7 · Authored by: Chengmo · Committed by: GitHub

Fix fluid.embedding in Distributed Training (#25174)

* test=develop, fix_embedding
Parent c701588b
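What this change addresses: `fluid.embedding` generates `lookup_table_v2` ops, whose ids tensor keeps its full rank, while the distributed lookup path was written for `lookup_table`, whose ids are a `[N, 1]` column. A rough sketch of the shape contracts as this diff treats them (read off the hunks below, not from framework docs):

```python
# Shape handling assumed by this patch (illustrative, not framework API):
#   lookup_table    : ids [N, 1]   -> out [N, emb_dim]
#   lookup_table_v2 : ids [d0, d1] -> out [d0, d1, emb_dim]
# Prefetch flattens v2 ids to [d0 * d1, 1], fetches the rows, then the
# kernel reshapes the output back to [d0, d1, emb_dim].
```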
@@ -209,16 +209,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
TableAndEndpoints tables;
for (auto& id_name : id_var_names) {
-    auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
-    auto* id_data = id_tensor.data<int64_t>();
+    auto* id_tensor =
+        scope.FindVar(id_name)->GetMutable<framework::LoDTensor>();
+    auto id_dims = id_tensor->dims();
+    id_tensor->Resize(framework::make_ddim(
+        {static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
+    auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids;
-    for (int64_t i = 0; i < id_tensor.numel(); ++i) {
+    for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
ids_group.push_back(ids);
-    ids_lods.push_back(id_tensor.lod());
+    ids_lods.push_back(id_tensor->lod());
}
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
......
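A minimal numpy sketch of the ids flattening done above before prefetch (shapes are assumed for illustration):

```python
import numpy as np

# ids as fluid.embedding (lookup_table_v2) produces them: rank-2, e.g. [batch, 1]
ids = np.array([[3], [7], [1]], dtype=np.int64)     # dims [3, 1]
d0, d1 = ids.shape
flat_ids = ids.reshape(d0 * d1, 1)                  # the Resize to [d0 * d1, 1]
# flat_ids is the [N, 1] column the per-row prefetch loop iterates over.
```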
@@ -26,7 +26,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
@@ -40,11 +40,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported.");
-    for (auto &ids_dim : ids_dims) {
+    for (auto& ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2.");
-      PADDLE_ENFORCE_EQ(ids_dim[1], 1,
-                        "The last dimension of the 'Ids' tensor must be 1.");
}
auto lookup_tables =
@@ -52,6 +50,8 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
auto height_sections =
ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
+    auto lookup_table_version =
+        ctx->Attrs().Get<std::string>("lookup_table_version");
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
lookup_tables.size() == endpoints.size() &&
@@ -61,8 +61,15 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
auto outputs_dims = std::vector<framework::DDim>();
-    for (auto &ids_dim : ids_dims) {
-      outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
+    for (auto& ids_dim : ids_dims) {
+      if (lookup_table_version == "lookup_table") {
+        outputs_dims.push_back(
+            framework::make_ddim({ids_dim[0], table_dims[1]}));
+      } else if (lookup_table_version == "lookup_table_v2") {
+        outputs_dims.push_back(framework::make_ddim(
+            {static_cast<int64_t>(ids_dim[0]), static_cast<int64_t>(ids_dim[1]),
+             static_cast<int64_t>(table_dims[1])}));
+      }
}
ctx->SetOutputsDim("Outputs", outputs_dims);
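A small pure-Python sketch of the output-shape rule this `InferShape` branch implements (the function name is hypothetical):

```python
def infer_output_dim(ids_dim, emb_dim, lookup_table_version):
    # lookup_table keeps a 2-D output; lookup_table_v2 appends emb_dim
    # to the full ids shape, mirroring the C++ branch above.
    if lookup_table_version == "lookup_table":
        return [ids_dim[0], emb_dim]
    elif lookup_table_version == "lookup_table_v2":
        return [ids_dim[0], ids_dim[1], emb_dim]

assert infer_output_dim([8, 1], 64, "lookup_table") == [8, 64]
assert infer_output_dim([8, 1], 64, "lookup_table_v2") == [8, 1, 64]
```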
@@ -71,7 +78,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
@@ -81,7 +88,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
template <typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> {
public:
-  void Compute(const framework::ExecutionContext &context) const override {
+  void Compute(const framework::ExecutionContext& context) const override {
auto ids_vars = context.MultiInputVar("Ids");
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
@@ -93,10 +100,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
+    auto lookup_table_version =
+        context.Attr<std::string>("lookup_table_version");
operators::distributed::prefetchs(
id_names, out_names, embedding_name, false, lookup_tables, endpoints,
height_sections, context, context.scope());
+    if (lookup_table_version == "lookup_table_v2") {
+      auto& scope = context.scope();
+      auto emb_dim =
+          scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
+      for (size_t i = 0; i < id_names.size(); ++i) {
+        auto* id_var = scope.FindVar(id_names[i]);
+        auto* out_var = scope.FindVar(out_names[i]);
+        auto* id_tensor = id_var->GetMutable<framework::LoDTensor>();
+        auto* out_tensor = out_var->GetMutable<framework::LoDTensor>();
+        auto id_dims = id_tensor->dims();
+        out_tensor->Resize(framework::make_ddim(
+            {static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
+             static_cast<int64_t>(emb_dim)}));
+      }
+    }
}
};
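In plain terms: prefetch fills a flat `[d0 * d1, emb_dim]` buffer, and for `lookup_table_v2` the kernel then re-views it with the ids' leading dims restored. A numpy sketch under those assumed shapes:

```python
import numpy as np

d0, d1, emb_dim = 2, 3, 4
flat_out = np.zeros((d0 * d1, emb_dim), dtype=np.float32)  # what prefetch wrote
out = flat_out.reshape(d0, d1, emb_dim)                    # the Resize above
assert out.shape == (2, 3, 4)
```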
@@ -134,6 +161,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<std::string>(
+        "lookup_table_version",
+        "(string, default lookup_table) "
+        "To distinguish between different versions of embedding OP")
+        .SetDefault(std::string("lookup_table"));
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
......
@@ -92,8 +92,8 @@ def train_network(batch_size,
# query
q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    q_emb = fluid.layers.embedding(
+    # embedding
+    q_emb = fluid.embedding(
input=q,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
@@ -104,10 +104,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"),
is_sparse=is_sparse)
-    ## vsum
+    q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
+    # vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum)
-    ## fc layer after conv
+    # fc layer after conv
q_fc = fluid.layers.fc(
input=q_ss,
size=hid_dim,
@@ -120,8 +121,8 @@ def train_network(batch_size,
# pt
pt = fluid.layers.data(
name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    pt_emb = fluid.layers.embedding(
+    # embedding
+    pt_emb = fluid.embedding(
input=pt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
@@ -132,10 +133,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"),
is_sparse=is_sparse)
-    ## vsum
+    pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
+    # vsum
pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
pt_ss = fluid.layers.softsign(pt_sum)
-    ## fc layer
+    # fc layer
pt_fc = fluid.layers.fc(
input=pt_ss,
size=hid_dim,
@@ -147,8 +149,8 @@ def train_network(batch_size,
# nt
nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    nt_emb = fluid.layers.embedding(
+    # embedding
+    nt_emb = fluid.embedding(
input=nt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
@@ -159,10 +161,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"),
is_sparse=is_sparse)
-    ## vsum
+    nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
+    # vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
nt_ss = fluid.layers.softsign(nt_sum)
-    ## fc layer
+    # fc layer
nt_fc = fluid.layers.fc(
input=nt_ss,
size=hid_dim,
......
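The test migration above hinges on one behavioral difference: `fluid.layers.embedding` squeezes `[-1, 1]` LoD ids to a `[-1, emb_dim]` output, while `fluid.embedding` keeps the ids' rank and yields `[-1, 1, emb_dim]`, hence the added `reshape` before `sequence_pool`. A condensed sketch of the migrated pattern (hyperparameter values assumed):

```python
import paddle.fluid as fluid

dict_dim, emb_dim = 10000, 128
q = fluid.layers.data(
    name="query_ids", shape=[1], dtype="int64", lod_level=1)
q_emb = fluid.embedding(
    input=q, size=[dict_dim, emb_dim], is_sparse=True)
# fluid.embedding output is [-1, 1, emb_dim]; flatten back to [-1, emb_dim]
# so sequence_pool sees the same shape fluid.layers.embedding used to produce.
q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
```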
@@ -46,7 +46,7 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False
self._enforce_place = "CPU"
-        #FIXME(typhoonzero): fix async tests later
+        # FIXME(typhoonzero): fix async tests later
def notest_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
@@ -107,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '1',
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
@@ -126,7 +126,7 @@ class TestDistSimnetBow2x2LookupTableAsync(TestDistBase):
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '1',
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
@@ -145,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '1',
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '0'
}
......
@@ -50,8 +50,8 @@ from .details import delete_ops, find_op_by_output_arg
from ..distribute_lookup_table import find_distributed_lookup_table
from . import collective
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]
LOOKUP_TABLE_GRAD_TYPE = ["lookup_table_grad", "lookup_table_v2_grad"]
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
@@ -140,7 +140,7 @@ def slice_variable(var_list, slice_count, min_block_size):
class DistributeTranspilerConfig(object):
"""
-    :api_attr: Static Graph
+    :api_attr: Static Graph
A configuration class that provide support for transpiler distributed jobs.
Some important parameters are explained as follows:
@@ -201,10 +201,10 @@ class DistributeTranspilerConfig(object):
geo_sgd_need_push_nums = 100
nccl_comm_num = 1
-    #The picture here illustrates the principle:
-    #https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
+    # The picture here illustrates the principle:
+    # https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
use_hierarchical_allreduce = False
-    #Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases.
+    # Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases.
hierarchical_allreduce_inter_nranks = 0
# if mode is collective
@@ -255,7 +255,7 @@ class ServerRuntimeConfig(object):
class DistributeTranspiler(object):
"""
-    :api_attr: Static Graph
+    :api_attr: Static Graph
**DistributeTranspiler**
@@ -449,7 +449,7 @@ class DistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce"]
sparse_update_op_types = ["lookup_table", "nce", "lookup_table_v2"]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True:
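With `LOOKUP_TABLE_TYPE` now a list, both embedding variants are picked up as remote sparse-update ops. A condensed sketch of the selection logic (simplified from `_get_all_remote_sparse_update_op`):

```python
LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]
sparse_update_op_types = ["lookup_table", "nce", "lookup_table_v2"]

def get_remote_sparse_update_ops(ops):
    # Keep ops whose type is a sparse-update type and which are
    # flagged for remote prefetch, as in the loop above.
    return [op for op in ops
            if op.type in sparse_update_op_types
            and op.attr('remote_prefetch') is True]
```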
@@ -479,7 +479,7 @@ class DistributeTranspiler(object):
ops.append(op)
used_ops.append(idx)
if op_type == "lookup_table":
if op_type in LOOKUP_TABLE_TYPE:
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
@@ -525,7 +525,8 @@ class DistributeTranspiler(object):
"height_sections": height_sections,
"endpoints": endpoints,
"padding_idx": padding_idx,
"trainer_id": self.trainer_id
"trainer_id": self.trainer_id,
"lookup_table_version": op_type
})
else:
raise ValueError(
@@ -613,10 +614,12 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
)
assert trainers_num > self.config.hierarchical_allreduce_inter_nranks, \
"trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks)
"trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(
trainers_num, self.config.hierarchical_allreduce_inter_nranks)
assert trainers_num % self.config.hierarchical_allreduce_inter_nranks == 0, \
"trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks)
"trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(
trainers_num, self.config.hierarchical_allreduce_inter_nranks)
self.origin_program._hierarchical_allreduce_inter_nranks = \
int(self.config.hierarchical_allreduce_inter_nranks)
@@ -782,7 +785,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
decay_dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
if self.config.runtime_split_send_recv:
-            ## async mode, using communicator to merge and send
+            # async mode, using communicator to merge and send
send_varnames = [self.counter_var.name]
else:
send_varnames = []
@@ -1019,7 +1022,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
- Delete optimizer related op, because parameter updated on Pserver
- After the op which computed gradient of each parameter, add ``Send_op`` and ``Recv_op``
Args:
wait_port(bool): Whether to wait for the parameter server to be ready before returning to program,
default is True
@@ -1076,7 +1079,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
sparse_table_names = self._get_sparse_table_names()
# self._fake_init_sparsetable(sparse_table_names)
-        #self._delete_trainer_optimizer(is_startup=True)
+        # self._delete_trainer_optimizer(is_startup=True)
for varname, splited_var in six.iteritems(self.param_var_mapping):
if varname in sparse_table_names:
@@ -1470,8 +1473,8 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
Program: parameter server side startup program.
Examples:
-        .. code-block:: python
+        .. code-block:: python
pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
current_endpoint = "192.168.0.1:6174"
@@ -2665,7 +2668,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
for op in block.ops:
if self._is_opt_role_op(op):
# Todo(chengmo): Whether clip related op belongs to Optimize guard should be discussed
-                # delete clip op from opt_ops when run in Parameter Server mode
+                # delete clip op from opt_ops when run in Parameter Server mode
if OP_NAME_SCOPE in op.all_attrs(
) and CLIP_OP_NAME_SCOPE in op.attr(
OP_NAME_SCOPE
@@ -2696,7 +2699,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
return opt_ops, params_grads
def _get_distribute_update_vars(self):
-        #TODO(chengmo): find more powerful and simple way to deal with these special situation
+        # TODO(chengmo): find more powerful and simple way to deal with these special situation
"""
This Function is used for a special model, like PyramidDnn which has pyramid hash op.
Some Parameters don't use optimizing op to update its value, but updated in its BP process.
......