Unverified commit 8177ece5, authored by tangwei12, committed by GitHub

fix entry (#31079) (#31182)

* fix entry

* fix distributed lookup table fuse case

* fix entry bug at first time

* move entry from paddle.fluid -> paddle.distributed

* fix ut with paddle.enable_static()
Co-authored-by: malin10 <malin10@baidu.com>
Parent fe00d32a
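
At a glance: this PR moves the entry config for `sparse_embedding` from `paddle.fluid` to `paddle.distributed` and fixes how the entry spec reaches the parameter server. A minimal usage sketch assembled from the docstrings added below (static graph mode; the size values are illustrative):

    import paddle

    paddle.enable_static()

    # Entry configs are now exposed as paddle.distributed.ProbabilityEntry /
    # paddle.distributed.CountFilterEntry (previously under paddle.fluid).
    entry = paddle.distributed.ProbabilityEntry(0.1)

    input = paddle.static.data(name='ins', shape=[1], dtype='int64')
    emb = paddle.static.nn.sparse_embedding(
        input=input,
        size=[1024, 64],
        is_test=False,
        entry=entry,
        param_attr=paddle.ParamAttr(name="SparseFeatFactors",
                                    initializer=paddle.nn.initializer.Uniform()))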
@@ -140,8 +140,9 @@ message CommonAccessorParameter {
   repeated string params = 4;
   repeated uint32 dims = 5;
   repeated string initializers = 6;
-  optional int32 trainer_num = 7;
-  optional bool sync = 8;
+  optional string entry = 7;
+  optional int32 trainer_num = 8;
+  optional bool sync = 9;
 }
 message TableAccessorSaveParameter {
......
@@ -238,12 +238,13 @@ int32_t CommonSparseTable::initialize() {
 int32_t CommonSparseTable::initialize_recorder() { return 0; }

 int32_t CommonSparseTable::initialize_value() {
+  auto common = _config.common();
   shard_values_.reserve(task_pool_size_);

   for (int x = 0; x < task_pool_size_; ++x) {
-    auto shard =
-        std::make_shared<ValueBlock>(value_names_, value_dims_, value_offsets_,
-                                     value_idx_, initializer_attrs_, "none");
+    auto shard = std::make_shared<ValueBlock>(
+        value_names_, value_dims_, value_offsets_, value_idx_,
+        initializer_attrs_, common.entry());

     shard_values_.emplace_back(shard);
   }
......
@@ -71,7 +71,7 @@ inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
 }

 inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
-  UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
+  UniformInitializer uniform = UniformInitializer({"uniform", "0", "0", "1"});
   return uniform.GetValue() >= threshold;
 }
@@ -93,20 +93,20 @@ class ValueBlock {
   // for Entry
   {
-    auto slices = string::split_string<std::string>(entry_attr, "&");
+    auto slices = string::split_string<std::string>(entry_attr, ":");
     if (slices[0] == "none") {
       entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
-    } else if (slices[0] == "count_filter") {
+    } else if (slices[0] == "count_filter_entry") {
       int threshold = std::stoi(slices[1]);
       entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
-    } else if (slices[0] == "probability") {
+    } else if (slices[0] == "probability_entry") {
       float threshold = std::stof(slices[1]);
       entry_func_ =
           std::bind(&probility_entry, std::placeholders::_1, threshold);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Not supported Entry Type : %s, Only support [count_filter, "
-          "probability]",
+          "Not supported Entry Type : %s, Only support [CountFilterEntry, "
+          "ProbabilityEntry]",
           slices[0]));
     }
   }
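
For reference, the `entry_attr` string parsed above is exactly what the Python `_to_attr()` methods added later in this diff produce. A hypothetical Python mirror of the dispatch, for illustration only (`parse_entry_attr` is not part of the commit):

    def parse_entry_attr(entry_attr):
        # entry_attr is "<name>" or "<name>:<threshold>", e.g. "count_filter_entry:10".
        slices = entry_attr.split(":")  # this PR switches the separator from "&" to ":"
        if slices[0] == "none":
            return ("count_filter", 0)  # threshold 0 admits every feature id
        elif slices[0] == "count_filter_entry":
            return ("count_filter", int(slices[1]))
        elif slices[0] == "probability_entry":
            return ("probability", float(slices[1]))
        raise ValueError("Not supported Entry Type : %s" % slices[0])

    parse_entry_attr("count_filter_entry:10")  # -> ("count_filter", 10)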
@@ -179,10 +179,12 @@ class ValueBlock {
         initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
                                    value_dims_[x]);
       }
+      value->need_save_ = true;
     }
+  } else {
+    value->need_save_ = true;
   }
-  value->need_save_ = true;

   return;
 }
......
@@ -78,6 +78,7 @@ void GetDownpourSparseTableProto(
   common_proto->set_table_name("MergedDense");
   common_proto->set_trainer_num(1);
   common_proto->set_sync(false);
+  common_proto->set_entry("none");
   common_proto->add_params("Param");
   common_proto->add_dims(10);
   common_proto->add_initializers("uniform_random&0&-1.0&1.0");
......
@@ -25,6 +25,9 @@ from paddle.distributed.fleet.dataset import *
 from . import collective
 from .collective import *

+from .entry_attr import ProbabilityEntry
+from .entry_attr import CountFilterEntry
+
 # start multiprocess apis
 __all__ = ["spawn"]
@@ -38,5 +41,17 @@ __all__ += [
     "QueueDataset",
 ]

+# dataset reader
+__all__ += [
+    "InMemoryDataset",
+    "QueueDataset",
+]
+
+# entry for embedding
+__all__ += [
+    "ProbabilityEntry",
+    "CountFilterEntry",
+]
+
 # collective apis
 __all__ += collective.__all__
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

__all__ = ['ProbabilityEntry', 'CountFilterEntry']


class EntryAttr(object):
    """
    Entry configuration for paddle.static.nn.sparse_embedding with the Parameter Server.

    Examples:
        .. code-block:: python

            import paddle

            sparse_feature_dim = 1024
            embedding_size = 64

            entry = paddle.distributed.ProbabilityEntry(0.1)

            input = paddle.static.data(name='ins', shape=[1], dtype='int64')

            emb = paddle.static.nn.sparse_embedding(
                input=input,
                size=[sparse_feature_dim, embedding_size],
                is_test=False,
                entry=entry,
                param_attr=paddle.ParamAttr(name="SparseFeatFactors",
                                            initializer=paddle.nn.initializer.Uniform()))
    """

    def __init__(self):
        self._name = None

    def _to_attr(self):
        """
        Returns the attributes of this parameter.

        Returns:
            Parameter attributes(map): The attributes of this parameter.
        """
        raise NotImplementedError("EntryAttr is base class")


class ProbabilityEntry(EntryAttr):
    """
    Examples:
        .. code-block:: python

            import paddle

            sparse_feature_dim = 1024
            embedding_size = 64

            entry = paddle.distributed.ProbabilityEntry(0.1)

            input = paddle.static.data(name='ins', shape=[1], dtype='int64')

            emb = paddle.static.nn.sparse_embedding(
                input=input,
                size=[sparse_feature_dim, embedding_size],
                is_test=False,
                entry=entry,
                param_attr=paddle.ParamAttr(name="SparseFeatFactors",
                                            initializer=paddle.nn.initializer.Uniform()))
    """

    def __init__(self, probability):
        super(ProbabilityEntry, self).__init__()

        if not isinstance(probability, float):
            raise ValueError("probability must be a float in (0,1)")

        if probability <= 0 or probability >= 1:
            raise ValueError("probability must be a float in (0,1)")

        self._name = "probability_entry"
        self._probability = probability

    def _to_attr(self):
        return ":".join([self._name, str(self._probability)])


class CountFilterEntry(EntryAttr):
    """
    Examples:
        .. code-block:: python

            import paddle

            sparse_feature_dim = 1024
            embedding_size = 64

            entry = paddle.distributed.CountFilterEntry(10)

            input = paddle.static.data(name='ins', shape=[1], dtype='int64')

            emb = paddle.static.nn.sparse_embedding(
                input=input,
                size=[sparse_feature_dim, embedding_size],
                is_test=False,
                entry=entry,
                param_attr=paddle.ParamAttr(name="SparseFeatFactors",
                                            initializer=paddle.nn.initializer.Uniform()))
    """

    def __init__(self, count_filter):
        super(CountFilterEntry, self).__init__()

        if not isinstance(count_filter, int):
            raise ValueError(
                "count_filter must be a valid integer greater than or equal to 0")

        if count_filter < 0:
            raise ValueError(
                "count_filter must be a valid integer greater than or equal to 0")

        self._name = "count_filter_entry"
        self._count_filter = count_filter

    def _to_attr(self):
        return ":".join([self._name, str(self._count_filter)])
@@ -58,6 +58,7 @@ class CommonAccessor:
     def __init__(self):
         self.accessor_class = ""
         self.table_name = None
+        self.entry = None
         self.attrs = []
         self.params = []
         self.dims = []
@@ -93,6 +94,24 @@ class CommonAccessor:
         self.opt_input_map = opt_input_map
         self.opt_init_map = opt_init_map

+    def parse_entry(self, varname, o_main_program):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_distributed_sparse_op
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_sparse_op
+
+        for op in o_main_program.global_block().ops:
+            if not is_distributed_sparse_op(op) and not is_sparse_op(op):
+                continue
+
+            param_name = op.input("W")[0]
+
+            if param_name == varname and op.type == "lookup_table":
+                self.entry = op.attr('entry')
+                break
+
+            if param_name == varname and op.type == "lookup_table_v2":
+                self.entry = "none"
+                break
+
     def get_shard(self, total_dim, shard_num, pserver_id):
         # remainder = total_dim % shard_num
         blocksize = int(total_dim / shard_num + 1)
@@ -188,6 +207,8 @@ class CommonAccessor:
         if self.table_name:
             attrs += "table_name: \"{}\" ".format(self.table_name)
+        if self.entry:
+            attrs += "entry: \"{}\" ".format(self.entry)
         attrs += "trainer_num: {} ".format(self.trainer_num)
         attrs += "sync: {} ".format(self.sync)
......
@@ -655,36 +676,31 @@ class TheOnePSRuntime(RuntimeBase):
                 use_origin_program=True,
                 split_dense_table=self.role_maker.
                 _is_heter_parameter_server_mode)

             tables = []
             for idx, (name, ctx) in enumerate(send_ctx.items()):
+                if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1:
+                    continue
+
                 table = Table()
                 table.id = ctx.table_id()
-
-                if ctx.is_tensor_table():
-                    continue
-
+                common = CommonAccessor()
                 if ctx.is_sparse():
-                    if len(ctx.origin_varnames()) < 1:
-                        continue
                     table.type = "PS_SPARSE_TABLE"
+                    table.shard_num = 256

                     if self.compiled_strategy.is_geo_mode():
                         table.table_class = "SparseGeoTable"
                     else:
                         table.table_class = "CommonSparseTable"
-                    table.shard_num = 256
-                else:
-                    if len(ctx.origin_varnames()) < 1:
-                        continue
-                    table.type = "PS_DENSE_TABLE"
-                    table.table_class = "CommonDenseTable"
-                    table.shard_num = 256
-
-                common = CommonAccessor()
-                if ctx.is_sparse():
+
                     common.table_name = self.compiled_strategy.grad_name_to_param_name[
                         ctx.origin_varnames()[0]]
                 else:
+                    table.type = "PS_DENSE_TABLE"
+                    table.table_class = "CommonDenseTable"
+                    table.shard_num = 256
                     common.table_name = "MergedDense"

                 common.parse_by_optimizer(ctx.origin_varnames()[0],
@@ -693,6 +709,10 @@ class TheOnePSRuntime(RuntimeBase):
                                           else ctx.sections()[0],
                                           self.compiled_strategy)

+                if ctx.is_sparse():
+                    common.parse_entry(common.table_name,
+                                       self.origin_main_program)
+
                 if is_sync:
                     common.sync = "true"
                 else:
......
@@ -46,7 +46,6 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check
 from paddle.fluid import core
 from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.entry_attr import ProbabilityEntry, CountFilterEntry
 from paddle.fluid.framework import Variable, convert_np_dtype_to_dtype_
 from paddle.fluid.layers import slice, reshape
......
@@ -993,11 +992,13 @@ def sparse_embedding(input,
     entry_str = "none"

     if entry is not None:
-        if not isinstance(entry, ProbabilityEntry) and not isinstance(
-                entry, CountFilterEntry):
+        if entry.__class__.__name__ not in [
+                "ProbabilityEntry", "CountFilterEntry"
+        ]:
             raise ValueError(
-                "entry must be instance in [ProbabilityEntry, CountFilterEntry]")
-        entry_str = entry.to_attr()
+                "entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry]"
+            )
+        entry_str = entry._to_attr()

     helper.append_op(
         type='lookup_table',
......
@@ -28,7 +28,7 @@ class EntryAttr(object):
     def __init__(self):
         self._name = None

-    def to_attr(self):
+    def _to_attr(self):
         """
         Returns the attributes of this parameter.
@@ -51,7 +51,7 @@ class ProbabilityEntry(EntryAttr):
         self._name = "probability_entry"
         self._probability = probability

-    def to_attr(self):
+    def _to_attr(self):
         return ":".join([self._name, str(self._probability)])
@@ -70,5 +70,5 @@ class CountFilterEntry(EntryAttr):
         self._name = "count_filter_entry"
         self._count_filter = count_filter

-    def to_attr(self):
+    def _to_attr(self):
         return ":".join([self._name, str(self._count_filter)])
@@ -172,8 +172,21 @@ def distributed_ops_pass(program, config):
                     "lookup_table_version": op_type
                 })
         else:
-            raise ValueError(
-                "something wrong with Fleet, submit a issue is recommended")
+            for i in range(len(inputs_idxs)):
+                distributed_idx = op_idxs[i] + 1
+
+                program.global_block()._insert_op(
+                    index=distributed_idx,
+                    type="distributed_lookup_table",
+                    inputs={"Ids": [inputs[i]],
+                            'W': w},
+                    outputs={"Outputs": [outputs[i]]},
+                    attrs={
+                        "is_distributed": is_distributed,
+                        "padding_idx": padding_idx,
+                        "table_id": table_id,
+                        "lookup_table_version": op_type
+                    })

     pull_sparse_ops = _get_pull_sparse_ops(program)
     _pull_sparse_fuse(program, pull_sparse_ops)
......
@@ -14,21 +14,24 @@
 from __future__ import print_function

+import paddle
+paddle.enable_static()
+
 import unittest
 import paddle.fluid as fluid
-from paddle.fluid.entry_attr import ProbabilityEntry, CountFilterEntry
+from paddle.distributed import ProbabilityEntry, CountFilterEntry


 class EntryAttrChecks(unittest.TestCase):
     def base(self):
         with self.assertRaises(NotImplementedError):
-            import paddle.fluid.entry_attr as entry
-            base = entry.EntryAttr()
-            base.to_attr()
+            from paddle.distributed.entry_attr import EntryAttr
+            base = EntryAttr()
+            base._to_attr()

     def probability_entry(self):
         prob = ProbabilityEntry(0.5)
-        ss = prob.to_attr()
+        ss = prob._to_attr()
         self.assertEqual("probability_entry:0.5", ss)

         with self.assertRaises(ValueError):
@@ -39,7 +42,7 @@ class EntryAttrChecks(unittest.TestCase):
     def countfilter_entry(self):
         counter = CountFilterEntry(20)
-        ss = counter.to_attr()
+        ss = counter._to_attr()
         self.assertEqual("count_filter_entry:20", ss)

         with self.assertRaises(ValueError):
@@ -61,7 +64,7 @@ class EntryAttrChecks(unittest.TestCase):
                 lod_level=1,
                 append_batch_size=False)
             prob = ProbabilityEntry(0.5)
-            emb = fluid.contrib.layers.sparse_embedding(
+            emb = paddle.static.nn.sparse_embedding(
                 input=input,
                 size=[100, 10],
                 is_test=False,
......
@@ -14,6 +14,9 @@
 from __future__ import print_function

+import paddle
+paddle.enable_static()
+
 import unittest
 import paddle.fluid as fluid
 from paddle.fluid.framework import default_main_program
......