未验证 提交 09e30eee 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #12696 from seiriosPlus/lookuptable_size_on_pserver_fix

lookup table size fix
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import delete_ops
......@@ -362,12 +364,13 @@ class TestL2DecayWithPiecewise(TranspilerTest):
class TestDistLookupTableBase(TranspilerTest):
def network_with_table(self, is_sparse, is_distributed):
self.table_size = 1000
self.emb_size = 64
def emb_pool(ids):
table_size = 1000
emb_size = 64
emb = fluid.layers.embedding(
input=ids,
size=[table_size, emb_size],
size=[self.table_size, self.emb_size],
dtype='float32',
param_attr='shared_w', # share parameter
is_sparse=is_sparse,
......@@ -536,6 +539,22 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
def net_conf(self):
self.network_with_table(is_sparse=True, is_distributed=True)
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
self.assertTrue(self.transpiler.has_distributed_lookup_table)
lookup_table_var = pserver1.global_block().vars[
self.transpiler.table_name]
row_size = lookup_table_var.shape[0]
calc_row_size = int(math.ceil(self.table_size / self.pservers))
self.assertEqual(row_size, calc_row_size)
class TestRMSPropOptimizer(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
......
......@@ -885,9 +885,15 @@ class DistributeTranspiler(object):
# create table param and grad var in pserver program
origin_param_var = self.origin_program.global_block().vars[
self.table_name]
zero_dim = int(
math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints)))
table_shape = list(origin_param_var.shape)
table_shape[0] = zero_dim
param_var = pserver_program.global_block().create_var(
name=origin_param_var.name,
shape=origin_param_var.shape,
shape=table_shape,
dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册