未验证 提交 4a54a464 编写于 作者: A Adam 提交者: GitHub

Add UT for SGD operator with large inputs (#23195)

上级 e3a078fb
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from op_test import OpTest from op_test import OpTest
...@@ -186,5 +187,26 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase): ...@@ -186,5 +187,26 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
self.check_with_place(place) self.check_with_place(place)
class TestSGDOpWithLargeInput(unittest.TestCase):
def runTest(self):
data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64')
label = fluid.layers.fill_constant(
shape=[1, 150], value=0.5, dtype='float32')
emb = fluid.embedding(input=data, size=(10000000, 150), dtype='float32')
out = fluid.layers.l2_normalize(x=emb, axis=-1)
cost = fluid.layers.square_error_cost(input=out, label=label)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
compiled_prog = fluid.compiler.CompiledProgram(
fluid.default_main_program())
result = exe.run(compiled_prog, fetch_list=[avg_cost])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册