提交 bb3ae206 编写于 作者: Y Yang Yang

nccl pass parallel_do test

上级 0d57ca46
...@@ -172,12 +172,18 @@ class ParallelOpTest(BaseParallelForTest): ...@@ -172,12 +172,18 @@ class ParallelOpTest(BaseParallelForTest):
loss = fluid.layers.mean(x=hidden) loss = fluid.layers.mean(x=hidden)
yield loss yield loss
def test_fc_with_tiny_data(self): def test_simple_fc(self):
self.run_test( self.run_test(
callback=self.__network__, callback=self.__network__,
feed={'img': numpy.random.random(size=(8, 784)).astype('float32')}, feed={'img': numpy.random.random(size=(8, 784)).astype('float32')},
fetch=['fc1.w@GRAD']) fetch=['fc1.w@GRAD'])
def test_fc_with_tiny_data(self):
self.run_test(
callback=self.__network__,
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
fetch=['fc1.w@GRAD'])
class ParallelOpTestMultipleInput(BaseParallelForTest): class ParallelOpTestMultipleInput(BaseParallelForTest):
@staticmethod @staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册