Unverified · Commit bc9956cc authored by HongyuJia, committed by GitHub

[CustomOP unittest] Polish unit test, phi->custom (#52670)

* [CustomOP unittest] Polish unit test, phi->custom

* Change phi->custom in custom_linear_op.cc
Parent 1ad943dd
@@ -17,16 +17,17 @@ limitations under the License. */
#include "paddle/extension.h"
// The linear operator implemented here requires a bias input
std::vector<paddle::Tensor> PhiLinearForward(const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& bias) {
std::vector<paddle::Tensor> CustomLinearForward(const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& bias) {
return {paddle::add(paddle::matmul(x, weight), bias)};
}
std::vector<paddle::Tensor> PhiLinearBackward(const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& bias,
const paddle::Tensor& out_grad) {
std::vector<paddle::Tensor> CustomLinearBackward(
const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& bias,
const paddle::Tensor& out_grad) {
auto x_grad = paddle::matmul(out_grad, weight, false, true);
auto weight_grad = paddle::matmul(x, out_grad, true, false);
auto bias_grad = paddle::experimental::sum(out_grad, {0});
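The transpose flags above encode the standard linear-layer gradients: for Out = matmul(X, Weight) + Bias, x_grad = out_grad · Weight^T, weight_grad = X^T · out_grad, and bias_grad reduces out_grad over axis 0. A minimal NumPy sketch of the same gradient math, with illustrative shapes assumed:

import numpy as np

x = np.random.rand(4, 3)          # batch 4, in_features 3 (assumed shapes)
weight = np.random.rand(3, 5)     # in_features 3, out_features 5
out_grad = np.random.rand(4, 5)

x_grad = out_grad @ weight.T      # matmul(out_grad, weight, false, true)
weight_grad = x.T @ out_grad      # matmul(x, out_grad, true, false)
bias_grad = out_grad.sum(axis=0)  # experimental::sum(out_grad, {0})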
@@ -96,14 +97,14 @@ std::vector<paddle::DataType> LinearInferDtype(
return {x_dtype};
}
PD_BUILD_OP(phi_linear)
PD_BUILD_OP(custom_linear)
.Inputs({"X", "Weight", "Bias"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(PhiLinearForward))
.SetKernelFn(PD_KERNEL(CustomLinearForward))
.SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype));
PD_BUILD_GRAD_OP(phi_linear)
PD_BUILD_GRAD_OP(custom_linear)
.Inputs({"X", "Weight", "Bias", paddle::Grad("Out")})
.Outputs({paddle::Grad("X"), paddle::Grad("Weight"), paddle::Grad("Bias")})
.SetKernelFn(PD_KERNEL(PhiLinearBackward));
.SetKernelFn(PD_KERNEL(CustomLinearBackward));
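With the registration renamed, the op surfaces in Python under the new symbol, so call sites move from custom_ops.phi_linear to custom_ops.custom_linear. A minimal usage sketch, assuming the module is JIT-built from this source file (the module name and file name below are assumptions for illustration):

import paddle
from paddle.utils.cpp_extension import load

# module name and sources are assumptions for illustration
custom_ops = load(name="custom_linear_jit", sources=["custom_linear_op.cc"])

x = paddle.randn([4, 3])
weight = paddle.randn([3, 5])
bias = paddle.randn([5])
out = custom_ops.custom_linear(x, weight, bias)  # previously custom_ops.phi_linear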
@@ -41,11 +41,11 @@ custom_inplace = load(
)
def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
def inplace_dynamic_add(custom_func, device, dtype, np_x, np_y):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
if phi_func:
if custom_func:
out = custom_inplace.custom_add(x, y)
else:
out = x.add_(y)
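Here add_ is Paddle's in-place add: it mutates x and returns it, so out aliases x; this is why the tests compare custom_x against custom_out under the "inplace_custom_x" label. A minimal sketch of the aliasing, assuming dynamic mode on CPU:

import paddle

paddle.set_device("cpu")
x = paddle.to_tensor([1.0, 2.0])
y = paddle.to_tensor([3.0, 4.0])
out = x.add_(y)  # in-place: x is updated and returned
assert (out.numpy() == x.numpy()).all()  # out aliases x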
@@ -88,14 +88,14 @@ def inplace_static_add(func, device, dtype, np_x, np_y):
return x_v, out_v, x_grad_v, y_grad_v, out_grad_v
def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y):
def inplace_dynamic_add_vector(custom_func, device, dtype, np_inputs, np_y):
paddle.set_device(device)
inputs = [
paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True)
for np_input in np_inputs
]
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
if phi_func:
if custom_func:
out = custom_inplace.custom_add_vec(inputs, y)
else:
out = [x.add_(y) for x in inputs]
@@ -111,7 +111,7 @@ def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y):
)
def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
def inplace_static_add_vector(custom_func, device, dtype, np_inputs, np_y):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
@@ -126,7 +126,7 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
x1.stop_gradient = False
x2.stop_gradient = False
y.stop_gradient = False
if phi_func:
if custom_func:
out = custom_inplace.custom_add_vec([x1, x2], y)
else:
out = [paddle.add(x1, y), paddle.add(x2, y)]
@@ -170,13 +170,13 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
)
def inplace_dynamic_relu_net(phi_func, device, dtype, np_x, np_y, np_z):
def inplace_dynamic_relu_net(custom_func, device, dtype, np_x, np_y, np_z):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
out_xy = x + y
if phi_func:
if custom_func:
out_xy = custom_inplace.custom_relu_inplace(out_xy)
out_xyz = out_xy + z
out = custom_inplace.custom_relu_inplace(out_xyz)
@@ -229,13 +229,13 @@ def inplace_static_relu_net(func, device, dtype, np_x, np_y, np_z):
return x_v, y_v, out_v, x_grad_v, y_grad_v
def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
def dynamic_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
a = paddle.to_tensor(np_a, dtype=dtype, stop_gradient=True)
b = paddle.to_tensor(np_b, dtype=dtype, stop_gradient=False)
if phi_func:
if custom_func:
out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b)
else:
out_xy = x.add_(y)
@@ -257,7 +257,7 @@ def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
)
def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
def static_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
@@ -270,7 +270,7 @@ def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
y.stop_gradient = False
a.stop_gradient = False
b.stop_gradient = False
if phi_func:
if custom_func:
out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b)
else:
out_xy = paddle.add(x, y)
@@ -379,11 +379,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
(
phi_x,
phi_out,
phi_x_grad,
phi_y_grad,
phi_out_grad,
custom_x,
custom_out,
custom_x_grad,
custom_y_grad,
custom_out_grad,
) = inplace_static_add(
custom_inplace.custom_add,
device,
@@ -391,15 +391,15 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_x,
self.np_y,
)
self.check_output(phi_x, phi_out, "inplace_phi_x")
self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output(
phi_x_grad, phi_out_grad, "inplace_phi_x_grad"
custom_x_grad, custom_out_grad, "inplace_custom_x_grad"
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_grad, pd_out_grad, "out_grad")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_out_grad, pd_out_grad, "out_grad")
def test_dynamic_add(self):
for device in self.devices:
@@ -418,11 +418,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
(
phi_x,
phi_y,
phi_out,
phi_x_grad,
phi_y_grad,
custom_x,
custom_y,
custom_out,
custom_x_grad,
custom_y_grad,
) = inplace_dynamic_add(
True,
device,
@@ -431,14 +431,14 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
self.check_output(phi_x, phi_out, "inplace_phi_x")
self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output(pd_x, pd_out, "inplace_pd_x")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_add_vector(self):
for device in self.devices:
@@ -456,10 +456,10 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
(
phi_out,
phi_x_grad,
phi_y_grad,
phi_out_grad,
custom_out,
custom_x_grad,
custom_y_grad,
custom_out_grad,
) = inplace_static_add_vector(
False,
device,
@@ -468,10 +468,10 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_grad, pd_out_grad, "out_grad")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_out_grad, pd_out_grad, "out_grad")
def test_dynamic_add_vector(self):
for device in self.devices:
@@ -490,11 +490,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
(
phi_x,
phi_y,
phi_out,
phi_x_grad,
phi_y_grad,
custom_x,
custom_y,
custom_out,
custom_x_grad,
custom_y_grad,
) = inplace_dynamic_add_vector(
False,
device,
@@ -503,14 +503,14 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
)
self.check_output(phi_x, phi_out, "inplace_phi_x")
self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output(pd_x, pd_out, "inplace_pd_x")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_relu_net(self):
for device in self.devices:
@@ -530,11 +530,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z,
)
(
phi_x,
phi_y,
phi_out,
phi_x_grad,
phi_y_grad,
custom_x,
custom_y,
custom_out,
custom_x_grad,
custom_y_grad,
) = inplace_static_relu_net(
custom_inplace.custom_relu_inplace,
device,
@@ -543,11 +543,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y,
self.np_z,
)
self.check_output_allclose(phi_x, pd_x, "x")
self.check_output_allclose(phi_y, pd_y, "y")
self.check_output_allclose(phi_out, pd_out, "out")
self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad")
self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad")
self.check_output_allclose(custom_x, pd_x, "x")
self.check_output_allclose(custom_y, pd_y, "y")
self.check_output_allclose(custom_out, pd_out, "out")
self.check_output_allclose(custom_x_grad, pd_x_grad, "x_grad")
self.check_output_allclose(custom_y_grad, pd_y_grad, "y_grad")
def test_dynamic_relu_net(self):
for device in self.devices:
@@ -567,11 +567,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z,
)
(
phi_x,
phi_y,
phi_out,
phi_x_grad,
phi_y_grad,
custom_x,
custom_y,
custom_out,
custom_x_grad,
custom_y_grad,
) = inplace_dynamic_relu_net(
True,
device,
@@ -581,11 +581,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_multi_inplace(self):
for device in self.devices:
@@ -611,16 +611,16 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b,
)
(
phi_x,
phi_out_xy,
phi_x_grad,
phi_y_grad,
phi_out_xy_grad,
phi_a,
phi_out_ab,
phi_a_grad,
phi_b_grad,
phi_out_ab_grad,
custom_x,
custom_out_xy,
custom_x_grad,
custom_y_grad,
custom_out_xy_grad,
custom_a,
custom_out_ab,
custom_a_grad,
custom_b_grad,
custom_out_ab_grad,
) = static_multi_inplace(
True,
device,
@@ -630,23 +630,27 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_a,
self.np_b,
)
self.check_output(phi_x, pd_out_xy, "inplace_phi_x")
self.check_output(custom_x, pd_out_xy, "inplace_custom_x")
self.check_output(
phi_x_grad, phi_out_xy_grad, "inplace_phi_x_grad"
custom_x_grad, custom_out_xy_grad, "inplace_custom_x_grad"
)
self.check_output(phi_a, pd_out_ab, "inplace_phi_a")
self.check_output(custom_a, pd_out_ab, "inplace_custom_a")
self.check_output(
phi_a_grad, phi_out_ab_grad, "inplace_phi_a_grad"
custom_a_grad, custom_out_ab_grad, "inplace_custom_a_grad"
)
self.check_output(phi_out_xy, pd_out_xy, "outxy")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_xy_grad, pd_out_xy_grad, "outxy_grad")
self.check_output(phi_out_ab, pd_out_ab, "outab")
self.check_output(phi_a_grad, pd_a_grad, "a_grad")
self.check_output(phi_b_grad, pd_b_grad, "b_grad")
self.check_output(phi_out_ab_grad, pd_out_ab_grad, "outab_grad")
self.check_output(custom_out_xy, pd_out_xy, "outxy")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(
custom_out_xy_grad, pd_out_xy_grad, "outxy_grad"
)
self.check_output(custom_out_ab, pd_out_ab, "outab")
self.check_output(custom_a_grad, pd_a_grad, "a_grad")
self.check_output(custom_b_grad, pd_b_grad, "b_grad")
self.check_output(
custom_out_ab_grad, pd_out_ab_grad, "outab_grad"
)
def test_dynamic_multi_inplace(self):
for device in self.devices:
@@ -672,16 +676,16 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b,
)
(
phi_x,
phi_y,
phi_out_xy,
phi_x_grad,
phi_y_grad,
phi_a,
phi_b,
phi_out_ab,
phi_a_grad,
phi_b_grad,
custom_x,
custom_y,
custom_out_xy,
custom_x_grad,
custom_y_grad,
custom_a,
custom_b,
custom_out_ab,
custom_a_grad,
custom_b_grad,
) = dynamic_multi_inplace(
True,
device,
@@ -692,21 +696,21 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b,
)
self.check_output(phi_x, phi_out_xy, "inplace_phi_x")
self.check_output(custom_x, custom_out_xy, "inplace_custom_x")
self.check_output(pd_x, pd_out_xy, "inplace_pd_x")
self.check_output(phi_a, phi_out_ab, "inplace_phi_a")
self.check_output(custom_a, custom_out_ab, "inplace_custom_a")
self.check_output(pd_a, pd_out_ab, "inplace_pd_a")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_out_xy, pd_out_xy, "outxy")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_a, pd_a, "a")
self.check_output(phi_b, pd_b, "b")
self.check_output(phi_out_ab, pd_out_ab, "outab")
self.check_output(phi_a_grad, pd_a_grad, "a_grad")
self.check_output(phi_b_grad, pd_b_grad, "b_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_out_xy, pd_out_xy, "outxy")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_a, pd_a, "a")
self.check_output(custom_b, pd_b, "b")
self.check_output(custom_out_ab, pd_out_ab, "outab")
self.check_output(custom_a_grad, pd_a_grad, "a_grad")
self.check_output(custom_b_grad, pd_b_grad, "b_grad")
if __name__ == "__main__":
......
@@ -112,12 +112,12 @@ class TestCustomLinearJit(unittest.TestCase):
for device in self.devices:
for dtype in self.dtypes:
(
phi_out,
phi_x_grad,
phi_weight_grad,
phi_bias_grad,
custom_out,
custom_x_grad,
custom_weight_grad,
custom_bias_grad,
) = linear_static(
custom_ops.phi_linear,
custom_ops.custom_linear,
device,
dtype,
self.np_x,
@@ -132,23 +132,23 @@ class TestCustomLinearJit(unittest.TestCase):
self.np_weight,
self.np_bias,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(
phi_weight_grad, pd_weight_grad, "weight_grad"
custom_weight_grad, pd_weight_grad, "weight_grad"
)
self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad")
self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad")
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
(
phi_out,
phi_x_grad,
phi_weight_grad,
phi_bias_grad,
custom_out,
custom_x_grad,
custom_weight_grad,
custom_bias_grad,
) = linear_dynamic(
custom_ops.phi_linear,
custom_ops.custom_linear,
device,
dtype,
self.np_x,
@@ -168,12 +168,12 @@ class TestCustomLinearJit(unittest.TestCase):
self.np_weight,
self.np_bias,
)
self.check_output(phi_out, pd_out, "phi_out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_out, pd_out, "custom_out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(
phi_weight_grad, pd_weight_grad, "weight_grad"
custom_weight_grad, pd_weight_grad, "weight_grad"
)
self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad")
self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad")
if __name__ == "__main__":
......
@@ -40,13 +40,13 @@ multi_out_module = load(
)
def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z):
def discrete_out_dynamic(use_custom, device, dtype, np_w, np_x, np_y, np_z):
paddle.set_device(device)
w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
if use_phi:
if use_custom:
out = multi_out_module.discrete_out(w, x, y, z)
else:
out = w * 1 + x * 2 + y * 3 + z * 4
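Since the reference computation is out = w*1 + x*2 + y*3 + z*4, the expected gradients are constants: d(out)/d(w) = 1 and d(out)/d(y) = 3 elementwise. A quick dynamic-mode sketch of those reference gradients, with assumed shapes:

import paddle

w = paddle.ones([2, 2])
w.stop_gradient = False
y = paddle.ones([2, 2])
y.stop_gradient = False
out = w * 1 + y * 3
out.backward(paddle.ones_like(out))
print(w.grad)  # all ones
print(y.grad)  # all threes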
@@ -55,7 +55,7 @@ def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z):
return out.numpy(), w.grad.numpy(), y.grad.numpy()
def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z):
def discrete_out_static(use_custom, device, dtype, np_w, np_x, np_y, np_z):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
@@ -68,7 +68,7 @@ def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z):
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
if use_phi:
if use_custom:
out = multi_out_module.discrete_out(w, x, y, z)
else:
out = w * 1 + x * 2 + y * 3 + z * 4
@@ -180,7 +180,11 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y,
self.np_z,
)
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_static(
(
custom_out,
custom_w_grad,
custom_y_grad,
) = discrete_out_static(
True,
device,
dtype,
@@ -189,10 +193,10 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y,
self.np_z,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(custom_out, pd_out, "out")
# NOTE: In static mode, the custom operator's output gradient has been optimized to shape=[1], while the native Paddle op's gradient has shape [4, 8], so we fetch pd_w_grad[0][0]. (Native Paddle's gradient also appears to be buggy here: entries other than pd_w_grad[0][0] are undefined in this unit test.)
self.check_output(phi_w_grad, pd_w_grad[0][0], "w_grad")
self.check_output(phi_y_grad, pd_y_grad[0][0], "y_grad")
self.check_output(custom_w_grad, pd_w_grad[0][0], "w_grad")
self.check_output(custom_y_grad, pd_y_grad[0][0], "y_grad")
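A small sketch of the comparison the NOTE describes, using the shapes it mentions (values are illustrative only):

import numpy as np

custom_w_grad = np.array([1.0])  # custom op grad, optimized to shape [1]
pd_w_grad = np.ones((4, 8))      # native op grad, shape [4, 8]
# per the NOTE, only pd_w_grad[0][0] is well-defined, so compare against it
np.testing.assert_allclose(custom_w_grad, pd_w_grad[0][0])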
def test_discrete_out_dynamic(self):
for device in self.devices:
@@ -206,7 +210,11 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y,
self.np_z,
)
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_dynamic(
(
custom_out,
custom_w_grad,
custom_y_grad,
) = discrete_out_dynamic(
True,
device,
dtype,
@@ -215,9 +223,9 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y,
self.np_z,
)
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_w_grad, pd_w_grad, "w_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_w_grad, pd_w_grad, "w_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
if __name__ == '__main__':
......