未验证 提交 c30e7b81 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Unittest] Add unit test for outputs with discrete order (#52348)

上级 496bbeab
......@@ -65,9 +65,58 @@ std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
}
// out = w * 1 + x * 2 + y * 3 + z * 4
std::vector<paddle::Tensor> DiscreteOutForward(const paddle::Tensor& w,
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::Tensor& z) {
paddle::Tensor out = w * 1 + x * 2 + y * 3 + z * 4;
return {out};
}
std::vector<std::vector<int64_t>> DiscreteOutInferShape(
const std::vector<int64_t>& w_shape,
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& y_shape,
const std::vector<int64_t>& z_shape) {
return {w_shape};
}
std::vector<paddle::DataType> DiscreteOutInferDtype(
const paddle::DataType& w_dtype,
const paddle::DataType& x_dtype,
const paddle::DataType& y_dtype,
const paddle::DataType& z_dtype) {
return {w_dtype};
}
// w_grad = out_grad
// y_grad = out_grad * 3
std::vector<paddle::Tensor> DiscreteOutBackward(
const paddle::Tensor& w,
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::Tensor& z,
const paddle::Tensor& out_grad) {
return {out_grad, out_grad * 3};
}
PD_BUILD_OP(multi_out)
.Inputs({"X"})
.Outputs({"Out", "Fake_float64", "ZFake_int32"})
.SetKernelFn(PD_KERNEL(MultiOutCPU))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
PD_BUILD_OP(discrete_out)
.Inputs({"w", "x", "y", "z"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(DiscreteOutForward))
.SetInferShapeFn(PD_INFER_SHAPE(DiscreteOutInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(DiscreteOutInferDtype));
// Test gradient operator whose output order is discrete.
PD_BUILD_GRAD_OP(discrete_out)
.Inputs({"w", "x", "y", "z", paddle::Grad("output")})
.Outputs({paddle::Grad("w"), paddle::Grad("y")})
.SetKernelFn(PD_KERNEL(DiscreteOutBackward));
......@@ -19,6 +19,7 @@ import numpy as np
from utils import extra_cc_args, paddle_includes
import paddle
from paddle import static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
......@@ -39,11 +40,79 @@ multi_out_module = load(
)
def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z):
paddle.set_device(device)
w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
if use_phi:
out = multi_out_module.discrete_out(w, x, y, z)
else:
out = w * 1 + x * 2 + y * 3 + z * 4
out.backward()
return out.numpy(), w.grad.numpy(), y.grad.numpy()
def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
w = static.data(name="w", shape=[None, np_x.shape[1]], dtype=dtype)
x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype)
z = static.data(name="z", shape=[None, np_z.shape[1]], dtype=dtype)
w.stop_gradient = False
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
if use_phi:
out = multi_out_module.discrete_out(w, x, y, z)
else:
out = w * 1 + x * 2 + y * 3 + z * 4
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
out_v, w_grad_v, y_grad_v = exe.run(
static.default_main_program(),
feed={
"w": np_w.astype(dtype),
"x": np_x.astype(dtype),
"y": np_y.astype(dtype),
"z": np_z.astype(dtype),
},
fetch_list=[
out.name,
w.name + "@GRAD",
y.name + "@GRAD",
],
)
paddle.disable_static()
return out_v, w_grad_v, y_grad_v
class TestMultiOutputDtypes(unittest.TestCase):
def setUp(self):
self.custom_op = multi_out_module.multi_out
self.dtypes = ['float32', 'float64']
self.devices = ['cpu']
self.np_w = np.random.uniform(-1, 1, [4, 8]).astype("float32")
self.np_x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
self.np_y = np.random.uniform(-1, 1, [4, 8]).astype("float32")
self.np_z = np.random.uniform(-1, 1, [4, 8]).astype("float32")
def check_output(self, out, pd_out, name):
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op {}: {},\n paddle api {}: {}'.format(
name, out, name, pd_out
),
)
def run_static(self, device, dtype):
paddle.set_device(device)
......@@ -80,7 +149,7 @@ class TestMultiOutputDtypes(unittest.TestCase):
one_int32, np.ones([4, 8]).astype('int32')
)
def test_static(self):
def test_multi_out_static(self):
paddle.enable_static()
for device in self.devices:
for dtype in self.dtypes:
......@@ -88,7 +157,7 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.check_multi_outputs(res)
paddle.disable_static()
def test_dynamic(self):
def test_multi_out_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
paddle.set_device(device)
......@@ -99,6 +168,57 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.assertTrue(len(outs) == 3)
self.check_multi_outputs(outs, True)
def test_discrete_out_static(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_out, pd_w_grad, pd_y_grad,) = discrete_out_static(
False,
device,
dtype,
self.np_w,
self.np_x,
self.np_y,
self.np_z,
)
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_static(
True,
device,
dtype,
self.np_w,
self.np_x,
self.np_y,
self.np_z,
)
self.check_output(phi_out, pd_out, "out")
# NOTE: In static mode, the output gradient of custom operator has been optimized to shape=[1]. However, native paddle op's output shape = [4, 8], hence we need to fetch pd_w_grad[0][0] (By the way, something wrong with native paddle's gradient, the outputs with other indexes instead of pd_w_grad[0][0] is undefined in this unittest.)
self.check_output(phi_w_grad, pd_w_grad[0][0], "w_grad")
self.check_output(phi_y_grad, pd_y_grad[0][0], "y_grad")
def test_discrete_out_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_out, pd_w_grad, pd_y_grad,) = discrete_out_dynamic(
False,
device,
dtype,
self.np_w,
self.np_x,
self.np_y,
self.np_z,
)
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_dynamic(
True,
device,
dtype,
self.np_w,
self.np_x,
self.np_y,
self.np_z,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_w_grad, pd_w_grad, "w_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册