提交 f702ab74 编写于 作者: J JiabinYang

add dist transpiler test

上级 50fce879
......@@ -875,5 +875,53 @@ class TestRemoteNce(TestDistLookupTableBase):
pass
# test for remote prefetch
class TestRemoteHsigmoid(TestDistLookupTableBase):
def network_with_table(self, is_sparse, is_distributed):
num_total_classes = 10
input = fluid.layers.data(name="input", shape=[10], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
path_table = fluid.layers.data(
name='path_table', shape=[10], dtype='int64')
path_code = fluid.layers.data(
name='path_code', shape=[10], dtype='int64')
w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 10],
dtype='float32',
name='hs_w',
initializer=fluid.initializer.ConstantInitializer())
b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1],
dtype='float32',
name='hs_b',
initializer=fluid.initializer.ConstantInitializer())
cost = fluid.layers.hsigmoid(
input=input,
label=label,
num_classes=non_leaf_num,
path_table=path_table,
path_code=path_code,
is_custom=True,
is_sparse=is_sparse)
avg_cost = fluid.layers.mean(cost)
# optimizer
optimizer = fluid.optimizer.SGD(learning_rate=0.003)
optimizer.minimize(avg_cost)
def net_conf(self):
import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
self.network_with_table(is_sparse=True, is_distributed=False)
def transpiler_test_impl(self):
trainer, _ = self.get_trainer()
for op in trainer.blocks[0].ops:
if op.type == "recv":
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册