未验证 提交 930b209e 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Extension] Add xpu backward testcase (#49027)

* add xpu backward testcase

* polish code

* fix self.custom_op error
上级 1ca86fc6
......@@ -31,6 +31,28 @@ void relu_cpu_forward_kernel(const data_t* x_data,
}
}
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel) {
for (int64_t i = 0; i < out_numel; ++i) {
grad_x_data[i] =
grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
template <typename data_t>
void relu_cpu_double_backward_kernel(const data_t* out_data,
const data_t* ddx_data,
data_t* ddout_data,
int64_t ddout_numel) {
for (int64_t i = 0; i < ddout_numel; ++i) {
ddout_data[i] =
ddx_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
CHECK_CPU_INPUT(x);
auto out = paddle::empty_like(x);
......@@ -44,12 +66,81 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
return {out};
}
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.data<data_t>(),
out.size());
}));
return {grad_x};
}
std::vector<paddle::Tensor> relu_cpu_double_backward(
const paddle::Tensor& out, const paddle::Tensor& ddx) {
CHECK_CPU_INPUT(out);
CHECK_CPU_INPUT(ddx);
auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] {
relu_cpu_double_backward_kernel<data_t>(
out.data<data_t>(),
ddx.data<data_t>(),
ddout.mutable_data<data_t>(out.place()),
ddout.size());
}));
std::cout << "Debug info: run relu cpu double backward success." << std::endl;
return {ddout};
}
std::vector<paddle::Tensor> relu_xpu_forward(const paddle::Tensor& x) {
CHECK_XPU_INPUT(x);
auto out = paddle::relu(x);
return {out};
}
std::vector<paddle::Tensor> relu_xpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
CHECK_XPU_INPUT(x);
CHECK_XPU_INPUT(out);
CHECK_XPU_INPUT(grad_out);
auto grad_x = paddle::empty_like(x, x.dtype(), x.place());
auto ones = paddle::experimental::full_like(x, 1.0, x.dtype(), x.place());
auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place());
auto condition = paddle::experimental::greater_than(x, zeros);
grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros));
return {grad_x};
}
std::vector<paddle::Tensor> relu_xpu_double_backward(
const paddle::Tensor& out, const paddle::Tensor& ddx) {
CHECK_XPU_INPUT(out);
CHECK_XPU_INPUT(ddx);
auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
auto ones =
paddle::experimental::full_like(out, 1.0, out.dtype(), out.place());
auto zeros =
paddle::experimental::full_like(out, 0.0, out.dtype(), out.place());
auto condition = paddle::experimental::greater_than(out, zeros);
ddout = paddle::multiply(ddx, paddle::where(condition, ones, zeros));
std::cout << "Debug info: run relu cpu double backward success." << std::endl;
return {ddout};
}
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
if (x.is_cpu()) {
return relu_cpu_forward(x);
......@@ -60,7 +151,47 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
}
}
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu()) {
return relu_cpu_backward(x, out, grad_out);
} else if (x.is_xpu()) {
return relu_xpu_backward(x, out, grad_out);
} else {
PD_THROW("Not implemented.");
}
}
std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
const paddle::Tensor& ddx) {
if (out.place() == paddle::PlaceType::kCPU) {
return relu_cpu_double_backward(out, ddx);
} else if (out.place().GetType() == phi::AllocationType::XPU) {
return relu_xpu_double_backward(out, ddx);
} else {
PD_THROW("Not implemented.");
}
}
std::vector<std::vector<int64_t>> ReluDoubleBackwardInferShape(
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& ddx_shape) {
return {out_shape};
}
PD_BUILD_OP(custom_relu)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward));
PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
.Inputs({"Out", paddle::Grad(paddle::Grad("X"))})
.Outputs({paddle::Grad(paddle::Grad("Out"))})
.SetKernelFn(PD_KERNEL(ReluDoubleBackward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape));
......@@ -23,15 +23,24 @@ import paddle
import paddle.static as static
from paddle.fluid.framework import _test_eager_guard
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.vision.transforms import Compose, Normalize
def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype)
t.stop_gradient = False
out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
out.backward()
return out.numpy()
if t.grad is None:
return out.numpy(), t.grad
else:
return out.numpy(), t.grad.numpy()
def custom_relu_static(
......@@ -43,7 +52,9 @@ def custom_relu_static(
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
......@@ -58,6 +69,97 @@ def custom_relu_static(
return out_v
def custom_relu_static_pe(func, device, dtype, np_x, use_func=True):
paddle.enable_static()
paddle.set_device(device)
places = static.xpu_places()
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x data has been covered by out
compiled_prog = static.CompiledProgram(
static.default_main_program()
).with_data_parallel(loss_name=out.name, places=places)
out_v = exe.run(
compiled_prog, feed={'X': np_x}, fetch_list=[out.name]
)
paddle.disable_static()
return out_v
def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
# simple module
data = static.data(
name='data', shape=[None, 1, 28, 28], dtype='float32'
)
label = static.data(name='label', shape=[None, 1], dtype='int64')
hidden = static.nn.fc(data, size=128)
hidden = func(hidden)
hidden = static.nn.fc(hidden, size=128)
predict = static.nn.fc(hidden, size=10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(input=hidden, label=label)
avg_loss = paddle.mean(loss)
opt = paddle.optimizer.SGD(learning_rate=0.1)
opt.minimize(avg_loss)
# run start up model
exe = static.Executor()
exe.run(static.default_startup_program())
# train
for _ in range(4):
exe.run(
static.default_main_program(),
feed={'data': np_data, 'label': np_label},
fetch_list=[avg_loss],
)
# save inference model
static.save_inference_model(path_prefix, [data], [predict], exe)
# get train predict value
predict_v = exe.run(
static.default_main_program(),
feed={'data': np_data, 'label': np_label},
fetch_list=[predict],
)
return predict_v
def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
dx = paddle.grad(
outputs=[out], inputs=[t], create_graph=True, retain_graph=True
)
dx[0].backward()
assert dx[0].grad is not None
return dx[0].numpy(), dx[0].grad.numpy()
class TestNewCustomOpSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
......@@ -110,12 +212,30 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
),
)
def test_static_pe(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_static_pe(self.custom_op, device, dtype, x)
pd_out = custom_relu_static_pe(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
def func_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_dynamic(self.custom_op, device, dtype, x)
pd_out = custom_relu_dynamic(
out, x_grad = custom_relu_dynamic(
self.custom_op, device, dtype, x
)
pd_out, pd_x_grad = custom_relu_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
......@@ -125,12 +245,141 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
out, pd_out
),
)
np.testing.assert_array_equal(
x_grad,
pd_x_grad,
err_msg='custom op x grad: {},\n paddle api x grad: {}'.format(
x_grad, pd_x_grad
),
)
def test_dynamic(self):
with _test_eager_guard():
self.func_dynamic()
self.func_dynamic()
def test_static_save_and_load_inference_model(self):
paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64")
path_prefix = "self.custom_op_inference/custom_relu"
for device in self.devices:
predict = custom_relu_static_inference(
self.custom_op, device, np_data, np_label, path_prefix
)
# load inference model
with static.scope_guard(static.Scope()):
exe = static.Executor()
[
inference_program,
feed_target_names,
fetch_targets,
] = static.load_inference_model(path_prefix, exe)
predict_infer = exe.run(
inference_program,
feed={feed_target_names[0]: np_data},
fetch_list=fetch_targets,
)
np.testing.assert_array_equal(
predict,
predict_infer,
err_msg='custom op predict: {},\n custom op infer predict: {}'.format(
predict, predict_infer
),
)
paddle.disable_static()
def test_static_save_and_run_inference_predictor(self):
paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64")
path_prefix = "self.custom_op_inference/custom_relu"
from paddle.inference import Config, create_predictor
for device in self.devices:
predict = custom_relu_static_inference(
self.custom_op, device, np_data, np_label, path_prefix
)
# load inference model
config = Config(
path_prefix + ".pdmodel", path_prefix + ".pdiparams"
)
predictor = create_predictor(config)
input_tensor = predictor.get_input_handle(
predictor.get_input_names()[0]
)
input_tensor.reshape(np_data.shape)
input_tensor.copy_from_cpu(np_data.copy())
predictor.run()
output_tensor = predictor.get_output_handle(
predictor.get_output_names()[0]
)
predict_infer = output_tensor.copy_to_cpu()
self.assertTrue(
np.isclose(predict, predict_infer, rtol=5e-5).any(),
"custom op predict: {},\n custom op infer predict: {}".format(
predict, predict_infer
),
)
paddle.disable_static()
def test_func_double_grad_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, dx_grad = custom_relu_double_grad_dynamic(
self.custom_op, device, dtype, x
)
pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
np.testing.assert_array_equal(
dx_grad,
pd_dx_grad,
err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format(
dx_grad, pd_dx_grad
),
)
def test_with_dataloader(self):
for device in self.devices:
paddle.set_device(device)
# data loader
transform = Compose(
[Normalize(mean=[127.5], std=[127.5], data_format='CHW')]
)
train_dataset = paddle.vision.datasets.MNIST(
mode='train', transform=transform
)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=0,
)
for batch_id, (image, _) in enumerate(train_loader()):
out = self.custom_op(image)
pd_out = paddle.nn.functional.relu(image)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
if batch_id == 5:
break
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册