未验证 提交 bc9956cc 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP unittest] Polish unit test, phi->custom (#52670)

* [CustomOP unittest] Polish unit test, phi->custom

* Change phi->custom in custom_linear_op.cc
上级 1ad943dd
...@@ -17,16 +17,17 @@ limitations under the License. */ ...@@ -17,16 +17,17 @@ limitations under the License. */
#include "paddle/extension.h" #include "paddle/extension.h"
// The linear op implemented here requires the bias to be passed in // The linear op implemented here requires the bias to be passed in
std::vector<paddle::Tensor> PhiLinearForward(const paddle::Tensor& x, std::vector<paddle::Tensor> CustomLinearForward(const paddle::Tensor& x,
const paddle::Tensor& weight, const paddle::Tensor& weight,
const paddle::Tensor& bias) { const paddle::Tensor& bias) {
return {paddle::add(paddle::matmul(x, weight), bias)}; return {paddle::add(paddle::matmul(x, weight), bias)};
} }
std::vector<paddle::Tensor> PhiLinearBackward(const paddle::Tensor& x, std::vector<paddle::Tensor> CustomLinearBackward(
const paddle::Tensor& weight, const paddle::Tensor& x,
const paddle::Tensor& bias, const paddle::Tensor& weight,
const paddle::Tensor& out_grad) { const paddle::Tensor& bias,
const paddle::Tensor& out_grad) {
auto x_grad = paddle::matmul(out_grad, weight, false, true); auto x_grad = paddle::matmul(out_grad, weight, false, true);
auto weight_grad = paddle::matmul(x, out_grad, true, false); auto weight_grad = paddle::matmul(x, out_grad, true, false);
auto bias_grad = paddle::experimental::sum(out_grad, {0}); auto bias_grad = paddle::experimental::sum(out_grad, {0});
...@@ -96,14 +97,14 @@ std::vector<paddle::DataType> LinearInferDtype( ...@@ -96,14 +97,14 @@ std::vector<paddle::DataType> LinearInferDtype(
return {x_dtype}; return {x_dtype};
} }
PD_BUILD_OP(phi_linear) PD_BUILD_OP(custom_linear)
.Inputs({"X", "Weight", "Bias"}) .Inputs({"X", "Weight", "Bias"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(PhiLinearForward)) .SetKernelFn(PD_KERNEL(CustomLinearForward))
.SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(LinearInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(LinearInferDtype));
PD_BUILD_GRAD_OP(phi_linear) PD_BUILD_GRAD_OP(custom_linear)
.Inputs({"X", "Weight", "Bias", paddle::Grad("Out")}) .Inputs({"X", "Weight", "Bias", paddle::Grad("Out")})
.Outputs({paddle::Grad("X"), paddle::Grad("Weight"), paddle::Grad("Bias")}) .Outputs({paddle::Grad("X"), paddle::Grad("Weight"), paddle::Grad("Bias")})
.SetKernelFn(PD_KERNEL(PhiLinearBackward)); .SetKernelFn(PD_KERNEL(CustomLinearBackward));
...@@ -41,11 +41,11 @@ custom_inplace = load( ...@@ -41,11 +41,11 @@ custom_inplace = load(
) )
def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): def inplace_dynamic_add(custom_func, device, dtype, np_x, np_y):
paddle.set_device(device) paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
if phi_func: if custom_func:
out = custom_inplace.custom_add(x, y) out = custom_inplace.custom_add(x, y)
else: else:
out = x.add_(y) out = x.add_(y)
...@@ -88,14 +88,14 @@ def inplace_static_add(func, device, dtype, np_x, np_y): ...@@ -88,14 +88,14 @@ def inplace_static_add(func, device, dtype, np_x, np_y):
return x_v, out_v, x_grad_v, y_grad_v, out_grad_v return x_v, out_v, x_grad_v, y_grad_v, out_grad_v
def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y): def inplace_dynamic_add_vector(custom_func, device, dtype, np_inputs, np_y):
paddle.set_device(device) paddle.set_device(device)
inputs = [ inputs = [
paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True) paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True)
for np_input in np_inputs for np_input in np_inputs
] ]
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
if phi_func: if custom_func:
out = custom_inplace.custom_add_vec(inputs, y) out = custom_inplace.custom_add_vec(inputs, y)
else: else:
out = [x.add_(y) for x in inputs] out = [x.add_(y) for x in inputs]
...@@ -111,7 +111,7 @@ def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y): ...@@ -111,7 +111,7 @@ def inplace_dynamic_add_vector(phi_func, device, dtype, np_inputs, np_y):
) )
def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): def inplace_static_add_vector(custom_func, device, dtype, np_inputs, np_y):
paddle.enable_static() paddle.enable_static()
paddle.set_device(device) paddle.set_device(device)
with static.scope_guard(static.Scope()): with static.scope_guard(static.Scope()):
...@@ -126,7 +126,7 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): ...@@ -126,7 +126,7 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
x1.stop_gradient = False x1.stop_gradient = False
x2.stop_gradient = False x2.stop_gradient = False
y.stop_gradient = False y.stop_gradient = False
if phi_func: if custom_func:
out = custom_inplace.custom_add_vec([x1, x2], y) out = custom_inplace.custom_add_vec([x1, x2], y)
else: else:
out = [paddle.add(x1, y), paddle.add(x2, y)] out = [paddle.add(x1, y), paddle.add(x2, y)]
...@@ -170,13 +170,13 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y): ...@@ -170,13 +170,13 @@ def inplace_static_add_vector(phi_func, device, dtype, np_inputs, np_y):
) )
def inplace_dynamic_relu_net(phi_func, device, dtype, np_x, np_y, np_z): def inplace_dynamic_relu_net(custom_func, device, dtype, np_x, np_y, np_z):
paddle.set_device(device) paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
out_xy = x + y out_xy = x + y
if phi_func: if custom_func:
out_xy = custom_inplace.custom_relu_inplace(out_xy) out_xy = custom_inplace.custom_relu_inplace(out_xy)
out_xyz = out_xy + z out_xyz = out_xy + z
out = custom_inplace.custom_relu_inplace(out_xyz) out = custom_inplace.custom_relu_inplace(out_xyz)
...@@ -229,13 +229,13 @@ def inplace_static_relu_net(func, device, dtype, np_x, np_y, np_z): ...@@ -229,13 +229,13 @@ def inplace_static_relu_net(func, device, dtype, np_x, np_y, np_z):
return x_v, y_v, out_v, x_grad_v, y_grad_v return x_v, y_v, out_v, x_grad_v, y_grad_v
def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): def dynamic_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b):
paddle.set_device(device) paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
a = paddle.to_tensor(np_a, dtype=dtype, stop_gradient=True) a = paddle.to_tensor(np_a, dtype=dtype, stop_gradient=True)
b = paddle.to_tensor(np_b, dtype=dtype, stop_gradient=False) b = paddle.to_tensor(np_b, dtype=dtype, stop_gradient=False)
if phi_func: if custom_func:
out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b) out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b)
else: else:
out_xy = x.add_(y) out_xy = x.add_(y)
...@@ -257,7 +257,7 @@ def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): ...@@ -257,7 +257,7 @@ def dynamic_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
) )
def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): def static_multi_inplace(custom_func, device, dtype, np_x, np_y, np_a, np_b):
paddle.enable_static() paddle.enable_static()
paddle.set_device(device) paddle.set_device(device)
with static.scope_guard(static.Scope()): with static.scope_guard(static.Scope()):
...@@ -270,7 +270,7 @@ def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b): ...@@ -270,7 +270,7 @@ def static_multi_inplace(phi_func, device, dtype, np_x, np_y, np_a, np_b):
y.stop_gradient = False y.stop_gradient = False
a.stop_gradient = False a.stop_gradient = False
b.stop_gradient = False b.stop_gradient = False
if phi_func: if custom_func:
out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b) out_xy, out_ab = custom_inplace.custom_multi_inplace(x, y, a, b)
else: else:
out_xy = paddle.add(x, y) out_xy = paddle.add(x, y)
...@@ -379,11 +379,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -379,11 +379,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
( (
phi_x, custom_x,
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
phi_out_grad, custom_out_grad,
) = inplace_static_add( ) = inplace_static_add(
custom_inplace.custom_add, custom_inplace.custom_add,
device, device,
...@@ -391,15 +391,15 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -391,15 +391,15 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_x, self.np_x,
self.np_y, self.np_y,
) )
self.check_output(phi_x, phi_out, "inplace_phi_x") self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output( self.check_output(
phi_x_grad, phi_out_grad, "inplace_phi_x_grad" custom_x_grad, custom_out_grad, "inplace_custom_x_grad"
) )
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_grad, pd_out_grad, "out_grad") self.check_output(custom_out_grad, pd_out_grad, "out_grad")
def test_dynamic_add(self): def test_dynamic_add(self):
for device in self.devices: for device in self.devices:
...@@ -418,11 +418,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -418,11 +418,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
( (
phi_x, custom_x,
phi_y, custom_y,
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
) = inplace_dynamic_add( ) = inplace_dynamic_add(
True, True,
device, device,
...@@ -431,14 +431,14 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -431,14 +431,14 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
self.check_output(phi_x, phi_out, "inplace_phi_x") self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output(pd_x, pd_out, "inplace_pd_x") self.check_output(pd_x, pd_out, "inplace_pd_x")
self.check_output(phi_x, pd_x, "x") self.check_output(custom_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y") self.check_output(custom_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_add_vector(self): def test_static_add_vector(self):
for device in self.devices: for device in self.devices:
...@@ -456,10 +456,10 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -456,10 +456,10 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
( (
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
phi_out_grad, custom_out_grad,
) = inplace_static_add_vector( ) = inplace_static_add_vector(
False, False,
device, device,
...@@ -468,10 +468,10 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -468,10 +468,10 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_grad, pd_out_grad, "out_grad") self.check_output(custom_out_grad, pd_out_grad, "out_grad")
def test_dynamic_add_vector(self): def test_dynamic_add_vector(self):
for device in self.devices: for device in self.devices:
...@@ -490,11 +490,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -490,11 +490,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
( (
phi_x, custom_x,
phi_y, custom_y,
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
) = inplace_dynamic_add_vector( ) = inplace_dynamic_add_vector(
False, False,
device, device,
...@@ -503,14 +503,14 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -503,14 +503,14 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
) )
self.check_output(phi_x, phi_out, "inplace_phi_x") self.check_output(custom_x, custom_out, "inplace_custom_x")
self.check_output(pd_x, pd_out, "inplace_pd_x") self.check_output(pd_x, pd_out, "inplace_pd_x")
self.check_output(phi_x, pd_x, "x") self.check_output(custom_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y") self.check_output(custom_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_relu_net(self): def test_static_relu_net(self):
for device in self.devices: for device in self.devices:
...@@ -530,11 +530,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -530,11 +530,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z, self.np_z,
) )
( (
phi_x, custom_x,
phi_y, custom_y,
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
) = inplace_static_relu_net( ) = inplace_static_relu_net(
custom_inplace.custom_relu_inplace, custom_inplace.custom_relu_inplace,
device, device,
...@@ -543,11 +543,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -543,11 +543,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_y, self.np_y,
self.np_z, self.np_z,
) )
self.check_output_allclose(phi_x, pd_x, "x") self.check_output_allclose(custom_x, pd_x, "x")
self.check_output_allclose(phi_y, pd_y, "y") self.check_output_allclose(custom_y, pd_y, "y")
self.check_output_allclose(phi_out, pd_out, "out") self.check_output_allclose(custom_out, pd_out, "out")
self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad") self.check_output_allclose(custom_x_grad, pd_x_grad, "x_grad")
self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad") self.check_output_allclose(custom_y_grad, pd_y_grad, "y_grad")
def test_dynamic_relu_net(self): def test_dynamic_relu_net(self):
for device in self.devices: for device in self.devices:
...@@ -567,11 +567,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -567,11 +567,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z, self.np_z,
) )
( (
phi_x, custom_x,
phi_y, custom_y,
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
) = inplace_dynamic_relu_net( ) = inplace_dynamic_relu_net(
True, True,
device, device,
...@@ -581,11 +581,11 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -581,11 +581,11 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_z, self.np_z,
) )
self.check_output(phi_x, pd_x, "x") self.check_output(custom_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y") self.check_output(custom_y, pd_y, "y")
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_static_multi_inplace(self): def test_static_multi_inplace(self):
for device in self.devices: for device in self.devices:
...@@ -611,16 +611,16 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -611,16 +611,16 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b, self.np_b,
) )
( (
phi_x, custom_x,
phi_out_xy, custom_out_xy,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
phi_out_xy_grad, custom_out_xy_grad,
phi_a, custom_a,
phi_out_ab, custom_out_ab,
phi_a_grad, custom_a_grad,
phi_b_grad, custom_b_grad,
phi_out_ab_grad, custom_out_ab_grad,
) = static_multi_inplace( ) = static_multi_inplace(
True, True,
device, device,
...@@ -630,23 +630,27 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -630,23 +630,27 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_a, self.np_a,
self.np_b, self.np_b,
) )
self.check_output(phi_x, pd_out_xy, "inplace_phi_x") self.check_output(custom_x, pd_out_xy, "inplace_custom_x")
self.check_output( self.check_output(
phi_x_grad, phi_out_xy_grad, "inplace_phi_x_grad" custom_x_grad, custom_out_xy_grad, "inplace_custom_x_grad"
) )
self.check_output(phi_a, pd_out_ab, "inplace_phi_a") self.check_output(custom_a, pd_out_ab, "inplace_custom_a")
self.check_output( self.check_output(
phi_a_grad, phi_out_ab_grad, "inplace_phi_a_grad" custom_a_grad, custom_out_ab_grad, "inplace_custom_a_grad"
) )
self.check_output(phi_out_xy, pd_out_xy, "outxy") self.check_output(custom_out_xy, pd_out_xy, "outxy")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_out_xy_grad, pd_out_xy_grad, "outxy_grad") self.check_output(
self.check_output(phi_out_ab, pd_out_ab, "outab") custom_out_xy_grad, pd_out_xy_grad, "outxy_grad"
self.check_output(phi_a_grad, pd_a_grad, "a_grad") )
self.check_output(phi_b_grad, pd_b_grad, "b_grad") self.check_output(custom_out_ab, pd_out_ab, "outab")
self.check_output(phi_out_ab_grad, pd_out_ab_grad, "outab_grad") self.check_output(custom_a_grad, pd_a_grad, "a_grad")
self.check_output(custom_b_grad, pd_b_grad, "b_grad")
self.check_output(
custom_out_ab_grad, pd_out_ab_grad, "outab_grad"
)
def test_dynamic_multi_inplace(self): def test_dynamic_multi_inplace(self):
for device in self.devices: for device in self.devices:
...@@ -672,16 +676,16 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -672,16 +676,16 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b, self.np_b,
) )
( (
phi_x, custom_x,
phi_y, custom_y,
phi_out_xy, custom_out_xy,
phi_x_grad, custom_x_grad,
phi_y_grad, custom_y_grad,
phi_a, custom_a,
phi_b, custom_b,
phi_out_ab, custom_out_ab,
phi_a_grad, custom_a_grad,
phi_b_grad, custom_b_grad,
) = dynamic_multi_inplace( ) = dynamic_multi_inplace(
True, True,
device, device,
...@@ -692,21 +696,21 @@ class TestCustomInplaceJit(unittest.TestCase): ...@@ -692,21 +696,21 @@ class TestCustomInplaceJit(unittest.TestCase):
self.np_b, self.np_b,
) )
self.check_output(phi_x, phi_out_xy, "inplace_phi_x") self.check_output(custom_x, custom_out_xy, "inplace_custom_x")
self.check_output(pd_x, pd_out_xy, "inplace_pd_x") self.check_output(pd_x, pd_out_xy, "inplace_pd_x")
self.check_output(phi_a, phi_out_ab, "inplace_phi_a") self.check_output(custom_a, custom_out_ab, "inplace_custom_a")
self.check_output(pd_a, pd_out_ab, "inplace_pd_a") self.check_output(pd_a, pd_out_ab, "inplace_pd_a")
self.check_output(phi_x, pd_x, "x") self.check_output(custom_x, pd_x, "x")
self.check_output(phi_y, pd_y, "y") self.check_output(custom_y, pd_y, "y")
self.check_output(phi_out_xy, pd_out_xy, "outxy") self.check_output(custom_out_xy, pd_out_xy, "outxy")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
self.check_output(phi_a, pd_a, "a") self.check_output(custom_a, pd_a, "a")
self.check_output(phi_b, pd_b, "b") self.check_output(custom_b, pd_b, "b")
self.check_output(phi_out_ab, pd_out_ab, "outab") self.check_output(custom_out_ab, pd_out_ab, "outab")
self.check_output(phi_a_grad, pd_a_grad, "a_grad") self.check_output(custom_a_grad, pd_a_grad, "a_grad")
self.check_output(phi_b_grad, pd_b_grad, "b_grad") self.check_output(custom_b_grad, pd_b_grad, "b_grad")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -112,12 +112,12 @@ class TestCustomLinearJit(unittest.TestCase): ...@@ -112,12 +112,12 @@ class TestCustomLinearJit(unittest.TestCase):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
( (
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_weight_grad, custom_weight_grad,
phi_bias_grad, custom_bias_grad,
) = linear_static( ) = linear_static(
custom_ops.phi_linear, custom_ops.custom_linear,
device, device,
dtype, dtype,
self.np_x, self.np_x,
...@@ -132,23 +132,23 @@ class TestCustomLinearJit(unittest.TestCase): ...@@ -132,23 +132,23 @@ class TestCustomLinearJit(unittest.TestCase):
self.np_weight, self.np_weight,
self.np_bias, self.np_bias,
) )
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output( self.check_output(
phi_weight_grad, pd_weight_grad, "weight_grad" custom_weight_grad, pd_weight_grad, "weight_grad"
) )
self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad") self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad")
def test_dynamic(self): def test_dynamic(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
( (
phi_out, custom_out,
phi_x_grad, custom_x_grad,
phi_weight_grad, custom_weight_grad,
phi_bias_grad, custom_bias_grad,
) = linear_dynamic( ) = linear_dynamic(
custom_ops.phi_linear, custom_ops.custom_linear,
device, device,
dtype, dtype,
self.np_x, self.np_x,
...@@ -168,12 +168,12 @@ class TestCustomLinearJit(unittest.TestCase): ...@@ -168,12 +168,12 @@ class TestCustomLinearJit(unittest.TestCase):
self.np_weight, self.np_weight,
self.np_bias, self.np_bias,
) )
self.check_output(phi_out, pd_out, "phi_out") self.check_output(custom_out, pd_out, "custom_out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad") self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output( self.check_output(
phi_weight_grad, pd_weight_grad, "weight_grad" custom_weight_grad, pd_weight_grad, "weight_grad"
) )
self.check_output(phi_bias_grad, pd_bias_grad, "bias_grad") self.check_output(custom_bias_grad, pd_bias_grad, "bias_grad")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -40,13 +40,13 @@ multi_out_module = load( ...@@ -40,13 +40,13 @@ multi_out_module = load(
) )
def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z): def discrete_out_dynamic(use_custom, device, dtype, np_w, np_x, np_y, np_z):
paddle.set_device(device) paddle.set_device(device)
w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False) w = paddle.to_tensor(np_w, dtype=dtype, stop_gradient=False)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
if use_phi: if use_custom:
out = multi_out_module.discrete_out(w, x, y, z) out = multi_out_module.discrete_out(w, x, y, z)
else: else:
out = w * 1 + x * 2 + y * 3 + z * 4 out = w * 1 + x * 2 + y * 3 + z * 4
...@@ -55,7 +55,7 @@ def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z): ...@@ -55,7 +55,7 @@ def discrete_out_dynamic(use_phi, device, dtype, np_w, np_x, np_y, np_z):
return out.numpy(), w.grad.numpy(), y.grad.numpy() return out.numpy(), w.grad.numpy(), y.grad.numpy()
def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z): def discrete_out_static(use_custom, device, dtype, np_w, np_x, np_y, np_z):
paddle.enable_static() paddle.enable_static()
paddle.set_device(device) paddle.set_device(device)
with static.scope_guard(static.Scope()): with static.scope_guard(static.Scope()):
...@@ -68,7 +68,7 @@ def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z): ...@@ -68,7 +68,7 @@ def discrete_out_static(use_phi, device, dtype, np_w, np_x, np_y, np_z):
x.stop_gradient = False x.stop_gradient = False
y.stop_gradient = False y.stop_gradient = False
z.stop_gradient = False z.stop_gradient = False
if use_phi: if use_custom:
out = multi_out_module.discrete_out(w, x, y, z) out = multi_out_module.discrete_out(w, x, y, z)
else: else:
out = w * 1 + x * 2 + y * 3 + z * 4 out = w * 1 + x * 2 + y * 3 + z * 4
...@@ -180,7 +180,11 @@ class TestMultiOutputDtypes(unittest.TestCase): ...@@ -180,7 +180,11 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y, self.np_y,
self.np_z, self.np_z,
) )
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_static( (
custom_out,
custom_w_grad,
custom_y_grad,
) = discrete_out_static(
True, True,
device, device,
dtype, dtype,
...@@ -189,10 +193,10 @@ class TestMultiOutputDtypes(unittest.TestCase): ...@@ -189,10 +193,10 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y, self.np_y,
self.np_z, self.np_z,
) )
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
# NOTE: In static mode, the output gradient of the custom operator has been optimized to shape=[1]. However, the native paddle op's output has shape=[4, 8], hence we need to fetch pd_w_grad[0][0]. (By the way, something is wrong with native paddle's gradient: the outputs at indexes other than pd_w_grad[0][0] are undefined in this unittest.) # NOTE: In static mode, the output gradient of the custom operator has been optimized to shape=[1]. However, the native paddle op's output has shape=[4, 8], hence we need to fetch pd_w_grad[0][0]. (By the way, something is wrong with native paddle's gradient: the outputs at indexes other than pd_w_grad[0][0] are undefined in this unittest.)
self.check_output(phi_w_grad, pd_w_grad[0][0], "w_grad") self.check_output(custom_w_grad, pd_w_grad[0][0], "w_grad")
self.check_output(phi_y_grad, pd_y_grad[0][0], "y_grad") self.check_output(custom_y_grad, pd_y_grad[0][0], "y_grad")
def test_discrete_out_dynamic(self): def test_discrete_out_dynamic(self):
for device in self.devices: for device in self.devices:
...@@ -206,7 +210,11 @@ class TestMultiOutputDtypes(unittest.TestCase): ...@@ -206,7 +210,11 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y, self.np_y,
self.np_z, self.np_z,
) )
(phi_out, phi_w_grad, phi_y_grad,) = discrete_out_dynamic( (
custom_out,
custom_w_grad,
custom_y_grad,
) = discrete_out_dynamic(
True, True,
device, device,
dtype, dtype,
...@@ -215,9 +223,9 @@ class TestMultiOutputDtypes(unittest.TestCase): ...@@ -215,9 +223,9 @@ class TestMultiOutputDtypes(unittest.TestCase):
self.np_y, self.np_y,
self.np_z, self.np_z,
) )
self.check_output(phi_out, pd_out, "out") self.check_output(custom_out, pd_out, "out")
self.check_output(phi_w_grad, pd_w_grad, "w_grad") self.check_output(custom_w_grad, pd_w_grad, "w_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad") self.check_output(custom_y_grad, pd_y_grad, "y_grad")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册