Unverified commit 0b597754 authored by Allen Guo, committed by GitHub

add ipu uts (#40205)

Parent: fe765cb3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestBase(IPUOpTest):
    def setUp(self):
        self.set_atol()
        self.set_training()
        self.set_data_feed()
        self.set_feed_attr()
        self.set_op_attrs()

    @property
    def fp16_enabled(self):
        return True

    def set_data_feed(self):
        data = np.random.uniform(size=[2, 2, 4, 6])
        self.feed_fp32 = {"in_0": data.astype(np.float32)}
        self.feed_fp16 = {"in_0": data.astype(np.float16)}

    def set_feed_attr(self):
        self.feed_shape = [x.shape for x in self.feed_fp32.values()]
        self.feed_list = list(self.feed_fp32.keys())

    def set_op_attrs(self):
        self.attrs = {}
        self.attrs['axis'] = 1

    def _test_base(self, exec_mode):
        scope = paddle.static.Scope()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = self.SEED
        startup_prog.random_seed = self.SEED

        with paddle.static.scope_guard(scope):
            with paddle.static.program_guard(main_prog, startup_prog):
                x = paddle.static.data(
                    name=self.feed_list[0],
                    shape=self.feed_shape[0],
                    dtype='float32')
                out = paddle.fluid.layers.flatten(x=x, **self.attrs)
                fetch_list = [out.name]

            if exec_mode == ExecutionMode.CPU_FP32:
                place = paddle.CPUPlace()
            else:
                place = paddle.IPUPlace()

            exe = paddle.static.Executor(place)
            exe.run(startup_prog)

            if exec_mode != ExecutionMode.CPU_FP32:
                # Compile the graph through the IPU backend; enable the
                # PopART fp16 path when requested.
                feed_list = self.feed_list
                ipu_strategy = paddle.static.IpuStrategy()
                ipu_strategy.set_graph_config(is_training=self.is_training)
                if exec_mode == ExecutionMode.IPU_POPART_FP16:
                    ipu_strategy.set_precision_config(enable_fp16=True)
                program = paddle.static.IpuCompiledProgram(
                    main_prog,
                    ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
            else:
                program = main_prog

            # Modes beyond IPU_FP32 run in half precision, so feed fp16 data.
            feed = self.feed_fp32
            if exec_mode > ExecutionMode.IPU_FP32:
                feed = self.feed_fp16

            result = exe.run(program, feed=feed, fetch_list=fetch_list)
            return result[0]

    def test_base(self):
        # Run every execution mode and compare the outputs against CPU fp32.
        output_dict = {}
        for mode in ExecutionMode:
            if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled:
                break
            output_dict[mode] = self._test_base(mode)

        self.check(output_dict, check_shape=True)


class TestCase1(TestBase):
    def set_op_attrs(self):
        self.attrs = {}
        self.attrs['axis'] = 0


class TestCase2(TestBase):
    def set_op_attrs(self):
        self.attrs = {}
        self.attrs['axis'] = 2


if __name__ == "__main__":
    unittest.main()
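
For reference on what these cases check: `flatten` with attribute `axis=k` collapses the input's first `k` dimensions into the first output dimension and the remaining dimensions into the second, giving a 2-D result (with `axis=0` the leading dimension becomes 1). A minimal NumPy sketch of the expected output shapes for the `[2, 2, 4, 6]` feed used above; the `flatten_shape` helper is illustrative, not part of the test:

```python
import numpy as np


def flatten_shape(shape, axis):
    # flatten(x, axis) returns a 2-D tensor: the first dim is the product
    # of shape[:axis] (1 when axis == 0), the second the product of the rest.
    lead = int(np.prod(shape[:axis])) if axis > 0 else 1
    tail = int(np.prod(shape[axis:]))
    return (lead, tail)


assert flatten_shape([2, 2, 4, 6], axis=1) == (2, 48)  # TestBase
assert flatten_shape([2, 2, 4, 6], axis=0) == (1, 96)  # TestCase1
assert flatten_shape([2, 2, 4, 6], axis=2) == (4, 24)  # TestCase2
```
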
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestBase(IPUOpTest):
    def setUp(self):
        self.set_atol()
        self.set_data_feed()
        self.set_feed_attr()
        self.set_attrs()

    def set_atol(self):
        self.atol = 1e-6

    def set_data_feed(self):
        self.feed = {
            "image": np.random.uniform(size=[1, 3, 10, 10]).astype('float32'),
        }

    def set_feed_attr(self):
        self.feed_shape = [x.shape for x in self.feed.values()]
        self.feed_list = list(self.feed.keys())
        self.feed_dtype = [x.dtype for x in self.feed.values()]

    def set_attrs(self):
        self.attrs = {
            "optimizer": 'sgd',
            "weight_decay": 0.0,
            "loss_scaling": 1.0,
        }

    def _test_optimizer(self, run_ipu=True):
        scope = paddle.static.Scope()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = self.SEED
        startup_prog.random_seed = self.SEED
        np.random.seed(self.SEED)

        with paddle.static.scope_guard(scope):
            with paddle.static.program_guard(main_prog, startup_prog):
                # A one-layer conv model; the mean of its output is the loss.
                image = paddle.static.data(
                    name='image', shape=[1, 3, 10, 10], dtype='float32')
                conv1 = paddle.static.nn.conv2d(
                    image, num_filters=3, filter_size=3, bias_attr=False)
                loss = paddle.mean(conv1)

                # Pick the optimizer under test; SGD is the default.
                weight_decay = self.attrs['weight_decay']
                opt = paddle.optimizer.SGD(learning_rate=1e-1,
                                           weight_decay=weight_decay)
                if self.attrs['optimizer'] == 'adam':
                    opt = paddle.optimizer.Adam(
                        learning_rate=1e-1, weight_decay=weight_decay)
                elif self.attrs['optimizer'] == 'lamb':
                    opt = paddle.optimizer.Lamb(
                        learning_rate=1e-1, lamb_weight_decay=weight_decay)
                opt.minimize(loss)

            if run_ipu:
                place = paddle.IPUPlace()
            else:
                place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(startup_prog)

            if run_ipu:
                feed_list = [image.name]
                fetch_list = [loss.name]
                ipu_strategy = paddle.static.IpuStrategy()
                ipu_strategy.set_graph_config(is_training=True)
                ipu_strategy.loss_scaling = self.attrs["loss_scaling"]
                program = paddle.static.IpuCompiledProgram(
                    main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
                                                                  fetch_list)
            else:
                program = main_prog

            # Train for 100 steps and record the loss trajectory.
            result = []
            for epoch in range(100):
                loss_res = exe.run(program, feed=self.feed, fetch_list=[loss])
                result.append(loss_res)
            return np.array(result)

    def test(self):
        # CPU and IPU fetch shapes differ, cpu: (100, 1, 1) vs ipu: (100, 1),
        # so flatten both before comparing.
        ipu_loss = self._test_optimizer(True).flatten()
        cpu_loss = self._test_optimizer(False).flatten()
        self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=self.atol))


@unittest.skip('do not support L2 regularization')
class TestSGD(TestBase):
    def set_attrs(self):
        self.attrs = {
            "optimizer": 'sgd',
            "weight_decay": 0.1,
            "loss_scaling": 2.0,
        }


@unittest.skip('do not support L2 regularization')
class TestAdamCase1(TestBase):
    def set_attrs(self):
        self.attrs = {
            "optimizer": 'adam',
            "weight_decay": 0.1,
            "loss_scaling": 3.0,
        }


class TestAdamCase2(TestBase):
    def set_attrs(self):
        self.attrs = {
            "optimizer": 'adam',
            "weight_decay": 0.0,
            "loss_scaling": 4.0,
        }


@unittest.skip('seems cpu output wrong')
class TestLambCase1(TestBase):
    def set_attrs(self):
        self.attrs = {
            "optimizer": 'lamb',
            "weight_decay": 0.0,
            "loss_scaling": 5.0,
        }


@unittest.skip('seems cpu output wrong')
class TestLamb(TestBase):
    def set_attrs(self):
        self.attrs = {
            "optimizer": 'lamb',
            "weight_decay": 0.1,
            "loss_scaling": 6.0,
        }


if __name__ == "__main__":
    unittest.main()
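
A note on the `loss_scaling` values the cases sweep (1.0 through 6.0): loss scaling multiplies the loss by a constant before backpropagation and divides the gradients by the same constant before the weight update, so in exact arithmetic the update, and hence the reported loss curve, is unchanged at any scale; the sweep checks that the IPU losses still match CPU. A minimal NumPy sketch of that invariant, assuming plain SGD on a scalar weight (all names here are illustrative, not Paddle API):

```python
import numpy as np


def sgd_step(w, grad_fn, lr, loss_scaling=1.0):
    # One SGD step with loss scaling: scaling the loss scales the gradient
    # by the same factor, which is divided back out before the update.
    scaled_grad = grad_fn(w) * loss_scaling  # gradient of (loss * scale)
    grad = scaled_grad / loss_scaling        # unscale before applying
    return w - lr * grad


# Toy problem: loss(w) = 0.5 * w**2, so d(loss)/dw = w.
grad_fn = lambda w: w
w0 = 3.0
for scale in (1.0, 2.0, 6.0):
    w = sgd_step(w0, grad_fn, lr=0.1, loss_scaling=scale)
    assert np.isclose(w, 2.7)  # identical update regardless of the scale
```
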