未验证 提交 6efb9f59 编写于 作者: Y yaozhixin 提交者: GitHub

update uts p1 (#39210)

上级 fd44de58
...@@ -856,16 +856,15 @@ if __name__ == "__main__": ...@@ -856,16 +856,15 @@ if __name__ == "__main__":
paddle.static.load(main_prog, "model/ernie") paddle.static.load(main_prog, "model/ernie")
if args.run_on_ipu: if args.run_on_ipu:
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.num_ipus = args.num_ipus ipu_strategy.SetGraphConfig(
ipu_strategy.enable_manual_shard = args.num_ipus > 1 num_ipus=args.num_ipus,
ipu_strategy.enable_pipelining = args.enable_pipelining is_training=args.is_training,
if args.enable_pipelining: enable_manual_shard=args.num_ipus > 1)
if args.is_training: ipu_strategy.SetPipeliningConfig(
ipu_strategy.batches_per_step = args.num_ipus + 1 enable_pipelining=args.enable_pipelining,
else: batches_per_step=args.num_ipus + 1)
ipu_strategy.batches_per_step = args.num_ipus
ipu_strategy.is_training = args.is_training
ipu_compiler = compiler.IPUCompiledProgram( ipu_compiler = compiler.IPUCompiledProgram(
main_prog, ipu_strategy=ipu_strategy) main_prog, ipu_strategy=ipu_strategy)
program = ipu_compiler.compile(feed_list, fetch_list) program = ipu_compiler.compile(feed_list, fetch_list)
......
...@@ -72,8 +72,8 @@ class TestRelu(IPUOpTest): ...@@ -72,8 +72,8 @@ class TestRelu(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IpuCompiler( program = compiler.IpuCompiler(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -81,10 +81,12 @@ class TestBase(IPUOpTest): ...@@ -81,10 +81,12 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(
# enable avg shard pass num_ipus=2,
ipu_strategy.need_avg_shard = True is_training=self.is_training,
enable_manual_shard=True,
need_avg_shard=True)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -79,8 +79,8 @@ class TestBase(IPUOpTest): ...@@ -79,8 +79,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest): ...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest): ...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest): ...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -93,8 +93,8 @@ class TestBase(IPUOpTest): ...@@ -93,8 +93,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest): ...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -75,8 +75,8 @@ class TestMul(IPUOpTest): ...@@ -75,8 +75,8 @@ class TestMul(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -84,8 +84,8 @@ class TestBase(IPUOpTest): ...@@ -84,8 +84,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -76,8 +76,8 @@ class TestBase(IPUOpTest): ...@@ -76,8 +76,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
...@@ -142,8 +142,8 @@ class TestCase1(TestBase): ...@@ -142,8 +142,8 @@ class TestCase1(TestBase):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -78,8 +78,8 @@ class TestBase(IPUOpTest): ...@@ -78,8 +78,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest): ...@@ -83,8 +83,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest): ...@@ -81,8 +81,8 @@ class TestBase(IPUOpTest):
if run_ipu: if run_ipu:
feed_list = self.feed_list feed_list = self.feed_list
ipu_strategy = compiler.get_ipu_strategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.is_training = self.is_training ipu_strategy.SetGraphConfig(is_training=self.is_training)
program = compiler.IPUCompiledProgram( program = compiler.IPUCompiledProgram(
main_prog, main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册