未验证 提交 fc7e3e99 编写于 作者: W WangXi 提交者: GitHub

fix sgd unittest timeout (#33665)

上级 930ca3f4
...@@ -22,6 +22,8 @@ from paddle.fluid.op import Operator ...@@ -22,6 +22,8 @@ from paddle.fluid.op import Operator
from op_test import OpTest from op_test import OpTest
import paddle import paddle
paddle.enable_static()
class TestSGDOp(OpTest): class TestSGDOp(OpTest):
def setUp(self): def setUp(self):
...@@ -226,33 +228,47 @@ class TestSGDV2(unittest.TestCase): ...@@ -226,33 +228,47 @@ class TestSGDV2(unittest.TestCase):
def test_sgd(self): def test_sgd(self):
paddle.enable_static() paddle.enable_static()
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
rms_optimizer = paddle.optimizer.SGD(learning_rate=0.1)
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost] def check_sgd_optimizer(optimizer_attr):
train_reader = paddle.batch( init_program = paddle.static.Program()
paddle.dataset.uci_housing.train(), batch_size=1) program = paddle.static.Program()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) block = program.global_block()
exe = fluid.Executor(place) mul_x = block.create_parameter(
exe.run(fluid.default_startup_program()) dtype="float32",
for data in train_reader(): shape=[5, 10],
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) lod_level=0,
name="mul.x",
optimize_attr=optimizer_attr)
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.01)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
return opts
opts = check_sgd_optimizer({'learning_rate': 1.1})
self.assertEqual(len(opts), 2)
self.assertEqual([op.type for op in opts], ["scale", "sgd"])
opts = check_sgd_optimizer({'learning_rate': 1.0})
self.assertEqual(len(opts), 1)
self.assertEqual([op.type for op in opts], ["sgd"])
def test_raise_error(self): def test_raise_error(self):
self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None) self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None)
def test_sgd_group_dygraph(self):
class TestSGDV2Group(TestSGDV2):
def test_sgd_dygraph(self):
paddle.disable_static() paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32") value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value) a = paddle.to_tensor(value)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册