未验证 提交 38f9b71b 编写于 作者: C Chengmo 提交者: GitHub

[cherry-pick] fix fluid.embedding (#25328)

* test=release/1.8, cherry fix fluid.embedding
上级 b69d0647
...@@ -209,16 +209,20 @@ void prefetchs(const std::vector<std::string>& id_var_names, ...@@ -209,16 +209,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
TableAndEndpoints tables; TableAndEndpoints tables;
for (auto& id_name : id_var_names) { for (auto& id_name : id_var_names) {
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>(); auto* id_tensor =
auto* id_data = id_tensor.data<int64_t>(); scope.FindVar(id_name)->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
id_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids; std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor.numel(); ++i) { for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]); ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]); ids_union.push_back(id_data[i]);
} }
ids_group.push_back(ids); ids_group.push_back(ids);
ids_lods.push_back(id_tensor.lod()); ids_lods.push_back(id_tensor->lod());
} }
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end()); std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
......
...@@ -26,7 +26,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -26,7 +26,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"), PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null."); "Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), PADDLE_ENFORCE(ctx->HasInput("W"),
...@@ -40,11 +40,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -40,11 +40,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(table_dims.size(), 2, PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported."); "Only 2 dimensions of the 'Embedding' is supported.");
for (auto &ids_dim : ids_dims) { for (auto& ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2, PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2."); "The dimension of the 'Ids' tensor must be 2.");
PADDLE_ENFORCE_EQ(ids_dim[1], 1,
"The last dimension of the 'Ids' tensor must be 1.");
} }
auto lookup_tables = auto lookup_tables =
...@@ -52,6 +50,8 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -52,6 +50,8 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
auto height_sections = auto height_sections =
ctx->Attrs().Get<std::vector<int64_t>>("height_sections"); ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints"); auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
auto lookup_table_version =
ctx->Attrs().Get<std::string>("lookup_table_version");
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() && PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
lookup_tables.size() == endpoints.size() && lookup_tables.size() == endpoints.size() &&
...@@ -61,8 +61,15 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,15 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
auto outputs_dims = std::vector<framework::DDim>(); auto outputs_dims = std::vector<framework::DDim>();
for (auto &ids_dim : ids_dims) { for (auto& ids_dim : ids_dims) {
outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]})); if (lookup_table_version == "lookup_table") {
outputs_dims.push_back(
framework::make_ddim({ids_dim[0], table_dims[1]}));
} else if (lookup_table_version == "lookup_table_v2") {
outputs_dims.push_back(framework::make_ddim(
{static_cast<int64_t>(ids_dim[0]), static_cast<int64_t>(ids_dim[1]),
static_cast<int64_t>(table_dims[1])}));
}
} }
ctx->SetOutputsDim("Outputs", outputs_dims); ctx->SetOutputsDim("Outputs", outputs_dims);
...@@ -71,7 +78,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -71,7 +78,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
...@@ -81,7 +88,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { ...@@ -81,7 +88,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
template <typename T> template <typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> { class DistributedLookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_vars = context.MultiInputVar("Ids"); auto ids_vars = context.MultiInputVar("Ids");
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings"); auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
...@@ -93,10 +100,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> { ...@@ -93,10 +100,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
auto height_sections = auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections"); context.Attr<std::vector<int64_t>>("height_sections");
auto endpoints = context.Attr<std::vector<std::string>>("endpoints"); auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
auto lookup_table_version =
context.Attr<std::string>("lookup_table_version");
operators::distributed::prefetchs( operators::distributed::prefetchs(
id_names, out_names, embedding_name, false, lookup_tables, endpoints, id_names, out_names, embedding_name, false, lookup_tables, endpoints,
height_sections, context, context.scope()); height_sections, context, context.scope());
if (lookup_table_version == "lookup_table_v2") {
auto& scope = context.scope();
auto emb_dim =
scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
for (size_t i = 0; i < id_names.size(); ++i) {
auto* id_var = scope.FindVar(id_names[i]);
auto* out_var = scope.FindVar(out_names[i]);
auto* id_tensor = id_var->GetMutable<framework::LoDTensor>();
auto* out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
out_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
static_cast<int64_t>(emb_dim)}));
}
}
} }
}; };
...@@ -134,6 +161,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -134,6 +161,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::string>(
"lookup_table_version",
"(string, default lookup_table) "
"To distinguish between different versions of embedding OP")
.SetDefault(std::string("lookup_table"));
AddAttr<int64_t>("padding_idx", AddAttr<int64_t>("padding_idx",
"(int64, default -1) " "(int64, default -1) "
"If the value is -1, it makes no effect to lookup. " "If the value is -1, it makes no effect to lookup. "
......
...@@ -92,8 +92,8 @@ def train_network(batch_size, ...@@ -92,8 +92,8 @@ def train_network(batch_size,
# query # query
q = fluid.layers.data( q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1) name="query_ids", shape=[1], dtype="int64", lod_level=1)
## embedding # embedding
q_emb = fluid.layers.embedding( q_emb = fluid.embedding(
input=q, input=q,
is_distributed=is_distributed, is_distributed=is_distributed,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
...@@ -104,10 +104,11 @@ def train_network(batch_size, ...@@ -104,10 +104,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01), initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"), name="__emb__"),
is_sparse=is_sparse) is_sparse=is_sparse)
## vsum q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
# vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum) q_ss = fluid.layers.softsign(q_sum)
## fc layer after conv # fc layer after conv
q_fc = fluid.layers.fc( q_fc = fluid.layers.fc(
input=q_ss, input=q_ss,
size=hid_dim, size=hid_dim,
...@@ -120,8 +121,8 @@ def train_network(batch_size, ...@@ -120,8 +121,8 @@ def train_network(batch_size,
# pt # pt
pt = fluid.layers.data( pt = fluid.layers.data(
name="pos_title_ids", shape=[1], dtype="int64", lod_level=1) name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
## embedding # embedding
pt_emb = fluid.layers.embedding( pt_emb = fluid.embedding(
input=pt, input=pt,
is_distributed=is_distributed, is_distributed=is_distributed,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
...@@ -132,10 +133,11 @@ def train_network(batch_size, ...@@ -132,10 +133,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01), initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"), name="__emb__"),
is_sparse=is_sparse) is_sparse=is_sparse)
## vsum pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
# vsum
pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum') pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
pt_ss = fluid.layers.softsign(pt_sum) pt_ss = fluid.layers.softsign(pt_sum)
## fc layer # fc layer
pt_fc = fluid.layers.fc( pt_fc = fluid.layers.fc(
input=pt_ss, input=pt_ss,
size=hid_dim, size=hid_dim,
...@@ -147,8 +149,8 @@ def train_network(batch_size, ...@@ -147,8 +149,8 @@ def train_network(batch_size,
# nt # nt
nt = fluid.layers.data( nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
## embedding # embedding
nt_emb = fluid.layers.embedding( nt_emb = fluid.embedding(
input=nt, input=nt,
is_distributed=is_distributed, is_distributed=is_distributed,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
...@@ -159,10 +161,11 @@ def train_network(batch_size, ...@@ -159,10 +161,11 @@ def train_network(batch_size,
initializer=fluid.initializer.Constant(value=0.01), initializer=fluid.initializer.Constant(value=0.01),
name="__emb__"), name="__emb__"),
is_sparse=is_sparse) is_sparse=is_sparse)
## vsum nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
# vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
nt_ss = fluid.layers.softsign(nt_sum) nt_ss = fluid.layers.softsign(nt_sum)
## fc layer # fc layer
nt_fc = fluid.layers.fc( nt_fc = fluid.layers.fc(
input=nt_ss, input=nt_ss,
size=hid_dim, size=hid_dim,
......
...@@ -46,7 +46,7 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): ...@@ -46,7 +46,7 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later # FIXME(typhoonzero): fix async tests later
def notest_simnet_bow(self): def notest_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
...@@ -107,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase): ...@@ -107,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def test_simnet_bow(self): def test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '1', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1' 'IS_SELF_CONTAINED_LR': '1'
} }
...@@ -126,7 +126,7 @@ class TestDistSimnetBow2x2LookupTableAsync(TestDistBase): ...@@ -126,7 +126,7 @@ class TestDistSimnetBow2x2LookupTableAsync(TestDistBase):
def test_simnet_bow(self): def test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '1', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1' 'IS_SELF_CONTAINED_LR': '1'
} }
...@@ -145,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase): ...@@ -145,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
def test_simnet_bow(self): def test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '1', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '0' 'IS_SELF_CONTAINED_LR': '0'
} }
......
...@@ -50,8 +50,8 @@ from .details import delete_ops, find_op_by_output_arg ...@@ -50,8 +50,8 @@ from .details import delete_ops, find_op_by_output_arg
from ..distribute_lookup_table import find_distributed_lookup_table from ..distribute_lookup_table import find_distributed_lookup_table
from . import collective from . import collective
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = ["lookup_table_grad", "lookup_table_v2_grad"]
OP_NAME_SCOPE = "op_namescope" OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP" CLIP_OP_NAME_SCOPE = "@CLIP"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
...@@ -199,10 +199,10 @@ class DistributeTranspilerConfig(object): ...@@ -199,10 +199,10 @@ class DistributeTranspilerConfig(object):
geo_sgd_need_push_nums = 100 geo_sgd_need_push_nums = 100
nccl_comm_num = 1 nccl_comm_num = 1
#The picture here illustrates the principle: # The picture here illustrates the principle:
#https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396 # https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
use_hierarchical_allreduce = False use_hierarchical_allreduce = False
#Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases. # Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases.
hierarchical_allreduce_inter_nranks = 0 hierarchical_allreduce_inter_nranks = 0
# if mode is collective # if mode is collective
...@@ -445,7 +445,7 @@ class DistributeTranspiler(object): ...@@ -445,7 +445,7 @@ class DistributeTranspiler(object):
def _get_all_remote_sparse_update_op(self, main_program): def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = [] sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce"] sparse_update_op_types = ["lookup_table", "nce", "lookup_table_v2"]
for op in main_program.global_block().ops: for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr( if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True: 'remote_prefetch') is True:
...@@ -475,7 +475,7 @@ class DistributeTranspiler(object): ...@@ -475,7 +475,7 @@ class DistributeTranspiler(object):
ops.append(op) ops.append(op)
used_ops.append(idx) used_ops.append(idx)
if op_type == "lookup_table": if op_type in LOOKUP_TABLE_TYPE:
all_ops = program.global_block().ops all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops] op_idxs = [all_ops.index(op) for op in ops]
inputs = [ inputs = [
...@@ -521,7 +521,8 @@ class DistributeTranspiler(object): ...@@ -521,7 +521,8 @@ class DistributeTranspiler(object):
"height_sections": height_sections, "height_sections": height_sections,
"endpoints": endpoints, "endpoints": endpoints,
"padding_idx": padding_idx, "padding_idx": padding_idx,
"trainer_id": self.trainer_id "trainer_id": self.trainer_id,
"lookup_table_version": op_type
}) })
else: else:
raise ValueError( raise ValueError(
...@@ -609,10 +610,12 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -609,10 +610,12 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
) )
assert trainers_num > self.config.hierarchical_allreduce_inter_nranks, \ assert trainers_num > self.config.hierarchical_allreduce_inter_nranks, \
"trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks) "trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(
trainers_num, self.config.hierarchical_allreduce_inter_nranks)
assert trainers_num % self.config.hierarchical_allreduce_inter_nranks == 0, \ assert trainers_num % self.config.hierarchical_allreduce_inter_nranks == 0, \
"trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks) "trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(
trainers_num, self.config.hierarchical_allreduce_inter_nranks)
self.origin_program._hierarchical_allreduce_inter_nranks = \ self.origin_program._hierarchical_allreduce_inter_nranks = \
int(self.config.hierarchical_allreduce_inter_nranks) int(self.config.hierarchical_allreduce_inter_nranks)
...@@ -778,7 +781,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -778,7 +781,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
decay_dummy_output = program.global_block().create_var( decay_dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
if self.config.runtime_split_send_recv: if self.config.runtime_split_send_recv:
## async mode, using communicator to merge and send # async mode, using communicator to merge and send
send_varnames = [self.counter_var.name] send_varnames = [self.counter_var.name]
else: else:
send_varnames = [] send_varnames = []
...@@ -1015,7 +1018,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -1015,7 +1018,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
- Delete optimizer related op, because parameter updated on Pserver - Delete optimizer related op, because parameter updated on Pserver
- After the op which computed gradient of each parameter, add ``Send_op`` and ``Recv_op`` - After the op which computed gradient of each parameter, add ``Send_op`` and ``Recv_op``
Args: Args:
wait_port(bool): Whether to wait for the parameter server to be ready before returning to program, wait_port(bool): Whether to wait for the parameter server to be ready before returning to program,
default is True default is True
...@@ -1072,7 +1075,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -1072,7 +1075,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
sparse_table_names = self._get_sparse_table_names() sparse_table_names = self._get_sparse_table_names()
# self._fake_init_sparsetable(sparse_table_names) # self._fake_init_sparsetable(sparse_table_names)
#self._delete_trainer_optimizer(is_startup=True) # self._delete_trainer_optimizer(is_startup=True)
for varname, splited_var in six.iteritems(self.param_var_mapping): for varname, splited_var in six.iteritems(self.param_var_mapping):
if varname in sparse_table_names: if varname in sparse_table_names:
...@@ -1466,8 +1469,8 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -1466,8 +1469,8 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
Program: parameter server side startup program. Program: parameter server side startup program.
Examples: Examples:
.. code-block:: python .. code-block:: python
pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174" pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174" trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
current_endpoint = "192.168.0.1:6174" current_endpoint = "192.168.0.1:6174"
...@@ -2661,7 +2664,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -2661,7 +2664,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
for op in block.ops: for op in block.ops:
if self._is_opt_role_op(op): if self._is_opt_role_op(op):
# Todo(chengmo): Whether clip related op belongs to Optimize guard should be discussed # Todo(chengmo): Whether clip related op belongs to Optimize guard should be discussed
# delete clip op from opt_ops when run in Parameter Server mode # delete clip op from opt_ops when run in Parameter Server mode
if OP_NAME_SCOPE in op.all_attrs( if OP_NAME_SCOPE in op.all_attrs(
) and CLIP_OP_NAME_SCOPE in op.attr( ) and CLIP_OP_NAME_SCOPE in op.attr(
OP_NAME_SCOPE OP_NAME_SCOPE
...@@ -2692,7 +2695,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler ...@@ -2692,7 +2695,7 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
return opt_ops, params_grads return opt_ops, params_grads
def _get_distribute_update_vars(self): def _get_distribute_update_vars(self):
#TODO(chengmo): find more powerful and simple way to deal with these special situation # TODO(chengmo): find more powerful and simple way to deal with these special situation
""" """
This Function is used for a special model, like PyramidDnn which has pyramid hash op. This Function is used for a special model, like PyramidDnn which has pyramid hash op.
Some Parameters don't use optimizing op to update its value, but updated in its BP process. Some Parameters don't use optimizing op to update its value, but updated in its BP process.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册