未验证 提交 9f5afa62 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Extension] Polish xpu testcase (#49158)

* clean custom_xpu testcase test_static_pe

* use assert_allclose to solve precision error

* adjust precision

* flatten tensor

* fix flatten
上级 72973d5a
...@@ -150,26 +150,29 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True): ...@@ -150,26 +150,29 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
out = func(t) if use_func else paddle.nn.functional.relu(t) out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
dx = paddle.grad( dx = paddle.grad(
outputs=[out], inputs=[t], create_graph=True, retain_graph=True outputs=out,
inputs=t,
grad_outputs=paddle.ones_like(t),
create_graph=True,
retain_graph=True,
) )
dx[0].backward() ddout = paddle.grad(
outputs=dx[0],
assert dx[0].grad is not None inputs=out.grad,
return dx[0].numpy(), dx[0].grad.numpy() grad_outputs=paddle.ones_like(t),
create_graph=False,
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
assert ddout[0].numpy() is not None
return dx[0].numpy(), ddout[0].numpy()
class TestNewCustomOpSetUpInstall(unittest.TestCase):
class TestNewCustomOpXpuSetUpInstall(unittest.TestCase):
def setUp(self): def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__)) cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile, install the custom op egg into site-packages under background
# Currently custom XPU op does not support Windows
if os.name == 'nt':
return
cmd = 'cd {} && {} custom_relu_xpu_setup.py install'.format( cmd = 'cd {} && {} custom_relu_xpu_setup.py install'.format(
cur_dir, sys.executable cur_dir, sys.executable
) )
...@@ -192,7 +195,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -192,7 +195,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
self.custom_op = custom_relu_xpu_module_setup.custom_relu self.custom_op = custom_relu_xpu_module_setup.custom_relu
self.dtypes = ['float32', 'float64'] self.dtypes = ['float32', 'float64']
self.devices = ['xpu'] self.device = 'xpu'
# config seed # config seed
SEED = 2021 SEED = 2021
...@@ -200,91 +203,90 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -200,91 +203,90 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
paddle.framework.random._manual_program_seed(SEED) paddle.framework.random._manual_program_seed(SEED)
def test_static(self): def test_static(self):
for device in self.devices: for dtype in self.dtypes:
for dtype in self.dtypes: x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) out = custom_relu_static(self.custom_op, self.device, dtype, x)
out = custom_relu_static(self.custom_op, device, dtype, x) pd_out = custom_relu_static(
pd_out = custom_relu_static( self.custom_op, self.device, dtype, x, False
self.custom_op, device, dtype, x, False )
) np.testing.assert_array_equal(
np.testing.assert_array_equal( out,
out, pd_out,
pd_out, err_msg='custom op out: {},\n paddle api out: {}'.format(
err_msg='custom op out: {},\n paddle api out: {}'.format( out, pd_out
out, pd_out ),
), )
)
def test_static_pe(self): def test_static_pe(self):
for device in self.devices: for dtype in self.dtypes:
for dtype in self.dtypes: x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) out = custom_relu_static_pe(self.custom_op, self.device, dtype, x)
out = custom_relu_static_pe(self.custom_op, device, dtype, x) pd_out = custom_relu_static_pe(
pd_out = custom_relu_static_pe( self.custom_op, self.device, dtype, x, False
self.custom_op, device, dtype, x, False )
) np.testing.assert_allclose(
np.testing.assert_array_equal( out,
out, pd_out,
pd_out, atol=1e-2,
err_msg='custom op out: {},\n paddle api out: {}'.format( err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out out, pd_out
), ),
) )
def test_dynamic(self): def test_dynamic(self):
for device in self.devices: for dtype in self.dtypes:
for dtype in self.dtypes: x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) out, x_grad = custom_relu_dynamic(
out, x_grad = custom_relu_dynamic( self.custom_op, self.device, dtype, x
self.custom_op, device, dtype, x )
) pd_out, pd_x_grad = custom_relu_dynamic(
pd_out, pd_x_grad = custom_relu_dynamic( self.custom_op, self.device, dtype, x, False
self.custom_op, device, dtype, x, False )
) np.testing.assert_array_equal(
np.testing.assert_array_equal( out,
out, pd_out,
pd_out, err_msg='custom op out: {},\n paddle api out: {}'.format(
err_msg='custom op out: {},\n paddle api out: {}'.format( out, pd_out
out, pd_out ),
), )
) np.testing.assert_array_equal(
np.testing.assert_array_equal( x_grad,
x_grad, pd_x_grad,
pd_x_grad, err_msg='custom op x grad: {},\n paddle api x grad: {}'.format(
err_msg='custom op x grad: {},\n paddle api x grad: {}'.format( x_grad, pd_x_grad
x_grad, pd_x_grad ),
), )
)
def test_static_save_and_load_inference_model(self): def test_static_save_and_load_inference_model(self):
paddle.enable_static() paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32") np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64") np_label = np.random.random((1, 1)).astype("int64")
path_prefix = "self.custom_op_inference/custom_relu" path_prefix = "self.custom_op_inference/custom_relu"
for device in self.devices:
predict = custom_relu_static_inference( predict = custom_relu_static_inference(
self.custom_op, device, np_data, np_label, path_prefix self.custom_op, self.device, np_data, np_label, path_prefix
)
# load inference model
with static.scope_guard(static.Scope()):
exe = static.Executor()
[
inference_program,
feed_target_names,
fetch_targets,
] = static.load_inference_model(path_prefix, exe)
predict_infer = exe.run(
inference_program,
feed={feed_target_names[0]: np_data},
fetch_list=fetch_targets,
)
np.testing.assert_allclose(
predict,
predict_infer,
atol=1e-2,
err_msg='custom op predict: {},\n custom op infer predict: {}'.format(
predict, predict_infer
),
) )
# load inference model
with static.scope_guard(static.Scope()):
exe = static.Executor()
[
inference_program,
feed_target_names,
fetch_targets,
] = static.load_inference_model(path_prefix, exe)
predict_infer = exe.run(
inference_program,
feed={feed_target_names[0]: np_data},
fetch_list=fetch_targets,
)
np.testing.assert_array_equal(
predict,
predict_infer,
err_msg='custom op predict: {},\n custom op infer predict: {}'.format(
predict, predict_infer
),
)
paddle.disable_static() paddle.disable_static()
def test_static_save_and_run_inference_predictor(self): def test_static_save_and_run_inference_predictor(self):
...@@ -294,92 +296,97 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -294,92 +296,97 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
path_prefix = "self.custom_op_inference/custom_relu" path_prefix = "self.custom_op_inference/custom_relu"
from paddle.inference import Config, create_predictor from paddle.inference import Config, create_predictor
for device in self.devices: predict = custom_relu_static_inference(
predict = custom_relu_static_inference( self.custom_op, self.device, np_data, np_label, path_prefix
self.custom_op, device, np_data, np_label, path_prefix )
) # load inference model
# load inference model config = Config(path_prefix + ".pdmodel", path_prefix + ".pdiparams")
config = Config( predictor = create_predictor(config)
path_prefix + ".pdmodel", path_prefix + ".pdiparams" input_tensor = predictor.get_input_handle(
predictor.get_input_names()[0]
)
input_tensor.reshape(np_data.shape)
input_tensor.copy_from_cpu(np_data.copy())
predictor.run()
output_tensor = predictor.get_output_handle(
predictor.get_output_names()[0]
)
predict_infer = output_tensor.copy_to_cpu()
predict = np.array(predict).flatten()
predict_infer = np.array(predict_infer).flatten()
np.testing.assert_allclose(
predict,
predict_infer,
rtol=5e-5,
atol=1e-2,
err_msg="custom op predict: {},\n custom op infer predict: {}".format(
predict, predict_infer
),
)
paddle.disable_static()
def test_func_double_grad_dynamic(self):
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, dx_grad = custom_relu_double_grad_dynamic(
self.custom_op, self.device, dtype, x
) )
predictor = create_predictor(config) pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
input_tensor = predictor.get_input_handle( self.custom_op, self.device, dtype, x, False
predictor.get_input_names()[0]
) )
input_tensor.reshape(np_data.shape) np.testing.assert_array_equal(
input_tensor.copy_from_cpu(np_data.copy()) out,
predictor.run() pd_out,
output_tensor = predictor.get_output_handle( err_msg='custom op out: {},\n paddle api out: {}'.format(
predictor.get_output_names()[0] out, pd_out
),
) )
predict_infer = output_tensor.copy_to_cpu() np.testing.assert_array_equal(
self.assertTrue( dx_grad,
np.isclose(predict, predict_infer, rtol=5e-5).any(), pd_dx_grad,
"custom op predict: {},\n custom op infer predict: {}".format( err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format(
predict, predict_infer dx_grad, pd_dx_grad
), ),
) )
paddle.disable_static()
def test_func_double_grad_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, dx_grad = custom_relu_double_grad_dynamic(
self.custom_op, device, dtype, x
)
pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
np.testing.assert_array_equal(
dx_grad,
pd_dx_grad,
err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format(
dx_grad, pd_dx_grad
),
)
def test_with_dataloader(self): def test_with_dataloader(self):
paddle.disable_static() paddle.disable_static()
for device in self.devices: paddle.set_device(self.device)
paddle.set_device(device) # data loader
# data loader transform = Compose(
transform = Compose( [Normalize(mean=[127.5], std=[127.5], data_format='CHW')]
[Normalize(mean=[127.5], std=[127.5], data_format='CHW')] )
) train_dataset = paddle.vision.datasets.MNIST(
train_dataset = paddle.vision.datasets.MNIST( mode='train', transform=transform
mode='train', transform=transform )
) train_loader = paddle.io.DataLoader(
train_loader = paddle.io.DataLoader( train_dataset,
train_dataset, batch_size=64,
batch_size=64, shuffle=True,
shuffle=True, drop_last=True,
drop_last=True, num_workers=0,
num_workers=0, )
)
for batch_id, (image, _) in enumerate(train_loader()): for batch_id, (image, _) in enumerate(train_loader()):
out = self.custom_op(image) out = self.custom_op(image)
pd_out = paddle.nn.functional.relu(image) pd_out = paddle.nn.functional.relu(image)
np.testing.assert_array_equal( np.testing.assert_allclose(
out, out,
pd_out, pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format( atol=1e-2,
out, pd_out err_msg='custom op out: {},\n paddle api out: {}'.format(
), out, pd_out
) ),
)
if batch_id == 5: if batch_id == 5:
break break
paddle.enable_static() paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
# compile, install the custom op egg into site-packages under background
# Currently custom XPU op does not support Windows
if os.name == 'nt':
exit()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册