提交 ffd5f44e 编写于 作者: S seiriosPlus

add UT for fuse

上级 d41d716f
......@@ -15,6 +15,8 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -27,8 +29,11 @@ class TestLookupTableFuseOp(unittest.TestCase):
self.check_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
program = fluid.Program()
scope = fluid.global_scope()
init_program = fluid.Program()
lr = scope.var("LearningRate")
lr.get_tensor().set([0.01], place)
ids = [i for i in range(100)]
out = scope.var("output")
......@@ -41,13 +46,13 @@ class TestLookupTableFuseOp(unittest.TestCase):
"embedding_2.block0:Param:8:0:embedding_2@GRAD.block0:embedding_2.block0,kSparseIDs@embedding_2.block0:uniform_random&0&-0.5&0.5:none"
)
program.global_block().append_op(
init_program.global_block().append_op(
type="lookup_sparse_table_init",
inputs=None,
outputs=None,
attrs={"large_scale_metas": metas})
program.global_block().append_op(
init_program.global_block().append_op(
type="lookup_sparse_table_read",
inputs={"Ids": ids},
outputs={"Out": out},
......@@ -57,7 +62,7 @@ class TestLookupTableFuseOp(unittest.TestCase):
"value_names": ["Param", "Moment1", "Moment2"],
})
program.global_block().append_op(
init_program.global_block().append_op(
type="lookup_sparse_table_read",
inputs={"Ids": ids},
outputs={"Out": out},
......@@ -68,7 +73,43 @@ class TestLookupTableFuseOp(unittest.TestCase):
})
executor = fluid.Executor(fluid.CPUPlace())
executor.run(program)
executor.run(init_program)
training_program = fluid.Program()
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 7
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
training_program.global_block().append_op(
type="lookup_sparse_table_fuse_adam",
inputs={"Grad": ids,
"LearningRate": lr},
outputs={"Out": out},
attrs={
"is_entry": False,
"tablename": "embedding_1.block0",
"value_names": ["Param", "Moment1", "Moment2"],
})
training_program.global_block().append_op(
type="lookup_sparse_table_fuse_sgd",
inputs={"Grad": ids,
"LearningRate": lr},
outputs={"Out": out},
attrs={
"is_entry": False,
"tablename": "embedding_2.block0",
"value_names": ["Param"],
})
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册