未验证 提交 59c9d75e 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Optional Inplace] Custom operator supports inplace optional vector Tensor input (#52421)

* [CustomOP Optional Inplace] Custom operator supports inplace optional vector Tensor input

* uncomment unittest codes
上级 0b60f28c
......@@ -833,7 +833,11 @@ static void RunInferDtypeFunc(
auto in_name = inplace_reverse_map.at(out_name);
// make sure ctx has valid inplace optional outputs
if (ctx->HasOutput(out_name)) {
ctx->SetOutputDataTypes(out_name, ctx->GetInputDataTypes(in_name));
size_t size = ctx->InputSize(in_name);
for (size_t i = 0; i < size; ++i) {
auto dtype = ctx->GetInputDataType(in_name, i);
ctx->SetOutputDataType(out_name, dtype, i);
}
} else {
PADDLE_ENFORCE(
detail::IsOptionalVar(out_name),
......
......@@ -120,6 +120,8 @@ class PADDLE_API CustomOpKernelContext {
std::vector<Tensor> InputsBetween(size_t start, size_t end) const;
Tensor& MutableInputAt(size_t idx);
paddle::optional<Tensor> OptionalInputAt(size_t idx);
paddle::optional<std::vector<Tensor>> OptionalInputsBetween(size_t start,
size_t end);
const std::vector<paddle::any>& Attrs() const { return attrs_; }
const std::vector<std::pair<size_t, size_t>>& InputRange() {
return input_range_;
......@@ -294,6 +296,34 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// Handle args for inplace vector<Tensor> case
template <typename... Tail>
struct ComputeCallHelper<std::vector<Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->InputsBetween(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
// Handle args for optional inplace vector<Tensor> case
template <typename... Tail>
struct ComputeCallHelper<paddle::optional<std::vector<Tensor>>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->OptionalInputsBetween(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
......@@ -358,20 +388,6 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// Handle args for inplace vector<Tensor> case
template <typename... Tail>
struct ComputeCallHelper<std::vector<Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->InputsBetween(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
template <int out_idx, typename T>
struct ComputeReturnHelper;
......
......@@ -110,6 +110,18 @@ paddle::optional<Tensor> CustomOpKernelContext::OptionalInputAt(size_t idx) {
return paddle::make_optional<paddle::Tensor>(inputs_.at(idx));
}
paddle::optional<std::vector<Tensor>>
CustomOpKernelContext::OptionalInputsBetween(size_t start, size_t end) {
std::vector<Tensor> rlt;
for (size_t i = start; i < end; ++i) {
if (!inputs_.at(i).is_initialized()) {
return paddle::none;
}
rlt.emplace_back(inputs_.at(i));
}
return paddle::make_optional<std::vector<Tensor>>(rlt);
}
Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) {
return &(outputs_.at(idx));
}
......
......@@ -1057,6 +1057,22 @@ def _gen_output_content(
if out_idx in inplace_reverse_idx:
in_idx = inplace_reverse_idx[out_idx]
if (
in_idx != -1
and "@VECTOR" in in_names[in_idx]
and "@OPTIONAL" in in_names[in_idx]
):
# inplace optional vector<Tensor> output case
lower_in_names = in_names[in_idx].split("@")[0].lower()
dynamic_content += f"""
{indent}if {lower_in_names} is not None:
{indent} outs['{out_name}'] = [core.eager.Tensor() for _ in range(len({lower_in_names}))]
{indent}else:
{indent} outs['{out_name}'] = core.eager.Tensor()
{indent}ctx.add_outputs(outs['{out_name}'])"""
static_content += f"""
{indent}if {lower_in_names} is not None:
{indent} outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]"""
elif (
in_idx != -1 and "@VECTOR" in in_names[in_idx]
): # inplace vector<Tensor> output case
lower_in_names = in_names[in_idx].split("@")[0].lower()
......
......@@ -300,3 +300,116 @@ PD_BUILD_GRAD_OP(custom_optional_inplace_add)
paddle::Grad(paddle::Optional("Y"))}})
.SetKernelFn(PD_KERNEL(AddOptionalInplaceBackward))
.SetInferShapeFn(PD_INFER_SHAPE(AddOptionalInplaceBackwardInferShape));
/*
if (y) {
outX = 2 * x + y[1...n];
outY[i] = x + y[i];
} else {
outX = 2 * x;
outY = None;
}
*/
std::vector<paddle::Tensor> AddOptionalInplaceVectorForward(
const paddle::Tensor& x,
paddle::optional<std::vector<paddle::Tensor>>& y) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor outX = paddle::zeros(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "AddOptionalInplaceVectorForward", ([&] {
add_two_pointers<data_t>(
x.data<data_t>(), x.data<data_t>(), outX.data<data_t>(), x.size());
if (y) {
for (size_t i = 0; i < y->size(); ++i) {
add_one_pointer<data_t>(
y->at(i).data<data_t>(), outX.data<data_t>(), outX.size());
add_one_pointer<data_t>(
x.data<data_t>(), y->at(i).data<data_t>(), x.size());
}
}
}));
// No need to return y, because we set it as inplace input.
return {outX};
}
std::vector<paddle::DataType> AddOptionalInplaceVectorInferDtype(
const paddle::DataType& x_dtype,
const paddle::optional<std::vector<paddle::DataType>>& y_dtype) {
return {x_dtype};
}
std::vector<std::vector<int64_t>> AddOptionalInplaceVectorInferShape(
const std::vector<int64_t>& x_shape,
const paddle::optional<std::vector<std::vector<int64_t>>>& y_shape) {
return {x_shape};
}
/*
if (outy_grad) {
x_grad = outX_grad * 2 + outY_grad[1...n];
y_grad[i] = outX_grad + outY_grad[i];
} else {
x_grad = outX_grad * 2;
y_grad = None;
}
*/
std::vector<paddle::Tensor> AddOptionalInplaceVectorBackward(
const paddle::Tensor& x,
const paddle::optional<std::vector<paddle::Tensor>>& y,
const paddle::Tensor& outx_grad,
paddle::optional<std::vector<paddle::Tensor>>& outy_grad) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
outx_grad.type(), "AddOptionalInplaceVectorBackward", ([&] {
add_two_pointers<data_t>(outx_grad.data<data_t>(),
outx_grad.data<data_t>(),
x_grad.data<data_t>(),
x_grad.size());
if (outy_grad) {
for (size_t i = 0; i < outy_grad->size(); ++i) {
add_one_pointer<data_t>(outy_grad->at(i).data<data_t>(),
x_grad.data<data_t>(),
x_grad.size());
add_one_pointer<data_t>(outx_grad.data<data_t>(),
outy_grad->at(i).data<data_t>(),
outx_grad.size());
}
}
}));
return {x_grad};
}
std::vector<std::vector<int64_t>> AddOptionalInplaceVectorBackwardInferShape(
const std::vector<int64_t>& x_shape,
const paddle::optional<std::vector<std::vector<int64_t>>>& y_shape,
const std::vector<int64_t>& x_grad_shape,
const paddle::optional<std::vector<std::vector<int64_t>>>& y_grad_shape) {
return {x_shape};
}
PD_BUILD_OP(custom_optional_inplace_add_vec)
.Inputs({"X", paddle::Optional(paddle::Vec("Y"))})
.Outputs({"OutX", paddle::Optional(paddle::Vec("OutY"))})
.SetInplaceMap({{paddle::Optional(paddle::Vec("Y")),
paddle::Optional(paddle::Vec("OutY"))}})
.SetKernelFn(PD_KERNEL(AddOptionalInplaceVectorForward))
.SetInferShapeFn(PD_INFER_SHAPE(AddOptionalInplaceVectorInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AddOptionalInplaceVectorInferDtype));
PD_BUILD_GRAD_OP(custom_optional_inplace_add_vec)
.Inputs({"X",
paddle::Optional(paddle::Vec("Y")),
paddle::Grad("OutX"),
paddle::Grad(paddle::Optional(paddle::Vec("OutY")))})
.Outputs({paddle::Grad("X"),
paddle::Grad(paddle::Optional(paddle::Vec("Y")))})
.SetInplaceMap({{paddle::Grad(paddle::Optional(paddle::Vec("OutY"))),
paddle::Grad(paddle::Optional(paddle::Vec("Y")))}})
.SetKernelFn(PD_KERNEL(AddOptionalInplaceVectorBackward))
.SetInferShapeFn(
PD_INFER_SHAPE(AddOptionalInplaceVectorBackwardInferShape));
......@@ -41,7 +41,7 @@ custom_optional = load(
)
def optional_dynamic_add(phi_func, device, dtype, np_x, np_y):
def optional_dynamic_add(custom_func, device, dtype, np_x, np_y):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
......@@ -49,7 +49,7 @@ def optional_dynamic_add(phi_func, device, dtype, np_x, np_y):
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
else:
y = x
if phi_func:
if custom_func:
out = custom_optional.custom_add(x, y if np_y is not None else None)
else:
out = paddle.add(x, y)
......@@ -58,7 +58,7 @@ def optional_dynamic_add(phi_func, device, dtype, np_x, np_y):
return x.numpy(), out.numpy(), x.grad.numpy()
def optional_static_add(phi_func, device, dtype, np_x, np_y):
def optional_static_add(custom_func, device, dtype, np_x, np_y):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
......@@ -79,7 +79,7 @@ def optional_static_add(phi_func, device, dtype, np_x, np_y):
feed_dict = {
"x": np_x.astype(dtype),
}
if phi_func:
if custom_func:
out = custom_optional.custom_add(
x, y if np_y is not None else None
)
......@@ -116,13 +116,13 @@ if (y) {
'''
def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
def optional_inplace_dynamic_add(custom_func, device, dtype, np_x, np_y):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
if np_y is not None:
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=True)
if phi_func:
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add(x, y)
else:
# We need to accumulate y's grad here.
......@@ -133,7 +133,7 @@ def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
outy = y.add_(x)
else:
y = None
if phi_func:
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add(x, y)
else:
outx = 2 * x
......@@ -155,7 +155,7 @@ def optional_inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
)
def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y):
def optional_inplace_static_add(custom_func, device, dtype, np_x, np_y):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
......@@ -171,7 +171,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y):
"x": np_x.astype(dtype),
"y": np_y.astype(dtype),
}
if phi_func:
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add(
x, y
)
......@@ -182,7 +182,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y):
feed_dict = {
"x": np_x.astype(dtype),
}
if phi_func:
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add(
x, None
)
......@@ -223,7 +223,7 @@ def optional_inplace_static_add(phi_func, device, dtype, np_x, np_y):
return [x_v, out_v, x_grad_v]
def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs):
def optional_vector_dynamic_add(custom_func, device, dtype, np_x, np_inputs):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
......@@ -232,14 +232,14 @@ def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs):
paddle.to_tensor(np_input, dtype=dtype, stop_gradient=False)
for np_input in np_inputs
]
if phi_func:
if custom_func:
out = custom_optional.custom_add_vec(x, inputs)
else:
out = paddle.add(x, inputs[0])
for input in inputs[1:]:
out = paddle.add(out, input)
else:
if phi_func:
if custom_func:
out = custom_optional.custom_add_vec(x, None)
else:
out = paddle.add(x, x)
......@@ -248,7 +248,7 @@ def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs):
return x.numpy(), out.numpy(), x.grad.numpy()
def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs):
def optional_vector_static_add(custom_func, device, dtype, np_x, np_inputs):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
......@@ -271,13 +271,13 @@ def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs):
"y2": np_inputs[1].astype(dtype),
}
)
if phi_func:
if custom_func:
out = custom_optional.custom_add_vec(x, [y1, y2])
else:
out = paddle.add(x, y1)
out = paddle.add(out, y2)
else:
if phi_func:
if custom_func:
out = custom_optional.custom_add_vec(x, None)
else:
out = paddle.add(x, x)
......@@ -301,6 +301,159 @@ def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs):
return x_v, out_v, x_grad_v
'''
if (y) {
outX = 2 * x + y[1...n];
outY[i] = x + y[i];
} else {
outX = 2 * x;
outY = None;
}
'''
def optional_inplace_vector_dynamic_add(
custom_func, device, dtype, np_x, np_inputs
):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
if np_inputs is not None:
inputs = [
paddle.to_tensor(np_input, dtype=dtype, stop_gradient=True)
for np_input in np_inputs
]
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add_vec(
x, inputs
)
else:
outx = 2 * x
outy = []
for input in inputs:
# We need to accumulate y's grad here.
input.stop_gradient = False
outx = outx + input
# Inplace leaf Tensor's stop_gradient should be True
input.stop_gradient = True
outy.append(input.add_(x))
else:
if custom_func:
outx, outy = custom_optional.custom_optional_inplace_add_vec(
x, None
)
else:
outx = 2 * x
outy = None
assert (
outy is None
), "The output `outy` of optional_inplace_dynamic_add should be None"
if outy is not None:
out = outx
for tensor in outy:
out = out + tensor
else:
out = outx
out.backward()
return (
x.numpy(),
outx.numpy(),
[y.numpy() for y in inputs] if np_inputs is not None else None,
[t.numpy() for t in outy] if outy is not None else None,
out.numpy(),
x.grad.numpy(),
[y.grad.numpy() for y in inputs]
if np_inputs is not None and inputs[0].grad is not None
else None,
)
def optional_inplace_vector_static_add(
custom_func, device, dtype, np_x, np_inputs
):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
x.stop_gradient = False
feed_dict = {
"x": np_x.astype(dtype),
}
if np_inputs is not None:
y1 = static.data(
name="y1", shape=[None, np_x.shape[1]], dtype=dtype
)
y1.stop_gradient = False
y2 = static.data(
name="y2", shape=[None, np_x.shape[1]], dtype=dtype
)
y2.stop_gradient = False
feed_dict.update(
{
"y1": np_inputs[0].astype(dtype),
"y2": np_inputs[1].astype(dtype),
}
)
if custom_func:
(
outx,
outy,
) = custom_optional.custom_optional_inplace_add_vec(
x, [y1, y2]
)
else:
outx = paddle.add(paddle.add(paddle.add(x, x), y1), y2)
# outx = 2 * x + y1 + y2
outy = [x + y1, x + y2]
else:
if custom_func:
(
outx,
outy,
) = custom_optional.custom_optional_inplace_add_vec(x, None)
else:
outx = 2 * x
outy = None
if np_inputs is not None:
out = outx + outy[0] + outy[1]
else:
out = outx
mean_out = paddle.mean(out)
static.append_backward(mean_out)
exe = static.Executor()
exe.run(static.default_startup_program())
if np_inputs is not None:
x_v, out_v, x_grad_v, y1_grad_v, y2_grad_v = exe.run(
static.default_main_program(),
feed=feed_dict,
fetch_list=[
x.name,
out.name,
x.name + "@GRAD",
y1.name + "@GRAD",
y2.name + "@GRAD",
],
)
paddle.disable_static()
return [x_v, out_v, x_grad_v, y1_grad_v, y2_grad_v]
else:
x_v, out_v, x_grad_v = exe.run(
static.default_main_program(),
feed=feed_dict,
fetch_list=[
x.name,
out.name,
x.name + "@GRAD",
],
)
paddle.disable_static()
return [x_v, out_v, x_grad_v]
class TestCustomOptionalJit(unittest.TestCase):
def setUp(self):
self.dtypes = ['float32', 'float64']
......@@ -317,6 +470,16 @@ class TestCustomOptionalJit(unittest.TestCase):
return
assert out is not None, "out value of " + name + " is None"
assert pd_out is not None, "pd_out value of " + name + " is None"
if isinstance(out, list) and isinstance(pd_out, list):
for idx in range(len(out)):
np.testing.assert_array_equal(
out[idx],
pd_out[idx],
err_msg='custom op {}: {},\n paddle api {}: {}'.format(
name, out[idx], name, pd_out[idx]
),
)
else:
np.testing.assert_array_equal(
out,
pd_out,
......@@ -351,7 +514,11 @@ class TestCustomOptionalJit(unittest.TestCase):
self.np_x,
np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
(
custom_x,
custom_out,
custom_x_grad,
) = optional_static_add(
True,
device,
dtype,
......@@ -359,9 +526,9 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
def test_optional_dynamic_add(self):
for device in self.devices:
......@@ -374,7 +541,11 @@ class TestCustomOptionalJit(unittest.TestCase):
self.np_x,
np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
(
custom_x,
custom_out,
custom_x_grad,
) = optional_dynamic_add(
True,
device,
dtype,
......@@ -382,9 +553,9 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
def test_optional_inplace_static_add(self):
for device in self.devices:
......@@ -397,7 +568,7 @@ class TestCustomOptionalJit(unittest.TestCase):
self.np_x,
np_y,
)
phi_tuple = optional_inplace_static_add(
custom_tuple = optional_inplace_static_add(
True,
device,
dtype,
......@@ -405,11 +576,13 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
self.check_output(phi_tuple[0], pd_tuple[0], "x")
self.check_output(phi_tuple[1], pd_tuple[1], "out")
self.check_output(phi_tuple[2], pd_tuple[2], "x_grad")
if len(phi_tuple) > 3:
self.check_output(phi_tuple[3], pd_tuple[3], "y_grad")
self.check_output(custom_tuple[0], pd_tuple[0], "x")
self.check_output(custom_tuple[1], pd_tuple[1], "out")
self.check_output(custom_tuple[2], pd_tuple[2], "x_grad")
if len(custom_tuple) > 3:
self.check_output(
custom_tuple[3], pd_tuple[3], "y_grad"
)
def test_optional_inplace_dynamic_add(self):
for device in self.devices:
......@@ -431,13 +604,13 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
(
phi_x,
phi_outx,
phi_y,
phi_outy,
phi_out,
phi_x_grad,
phi_y_grad,
custom_x,
custom_outx,
custom_y,
custom_outy,
custom_out,
custom_x_grad,
custom_y_grad,
) = optional_inplace_dynamic_add(
True,
device,
......@@ -447,21 +620,25 @@ class TestCustomOptionalJit(unittest.TestCase):
)
self.check_output(pd_y, pd_outy, "inplace_pd_y")
self.check_output(phi_y, phi_outy, "inplace_phi_y")
self.check_output(custom_y, custom_outy, "inplace_custom_y")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_outx, pd_outx, "outx")
self.check_output(phi_y, pd_y, "y")
self.check_output(phi_outy, pd_outy, "outy")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_y_grad, pd_y_grad, "y_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_outx, pd_outx, "outx")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_outy, pd_outy, "outy")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
def test_optional_vector_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
for np_y in [None, self.np_inputs]:
(phi_x, phi_out, phi_x_grad,) = optional_vector_static_add(
(
custom_x,
custom_out,
custom_x_grad,
) = optional_vector_static_add(
True,
device,
dtype,
......@@ -476,15 +653,19 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
def test_optional_vector_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
for np_y in [None, self.np_inputs]:
(phi_x, phi_out, phi_x_grad,) = optional_vector_dynamic_add(
(
custom_x,
custom_out,
custom_x_grad,
) = optional_vector_dynamic_add(
True,
device,
dtype,
......@@ -499,9 +680,85 @@ class TestCustomOptionalJit(unittest.TestCase):
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
def test_optional_inplace_vector_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
for np_y in [None, self.np_inputs]:
pd_tuple = optional_inplace_vector_static_add(
False,
device,
dtype,
self.np_x,
np_y,
)
custom_tuple = optional_inplace_vector_static_add(
True,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(custom_tuple[0], pd_tuple[0], "x")
self.check_output(custom_tuple[1], pd_tuple[1], "out")
self.check_output(custom_tuple[2], pd_tuple[2], "x_grad")
if len(custom_tuple) > 3:
self.check_output(
custom_tuple[3], pd_tuple[3], "y1_grad"
)
self.check_output(
custom_tuple[4], pd_tuple[4], "y2_grad"
)
def test_optional_inplace_vector_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
for np_y in [None, self.np_inputs]:
(
custom_x,
custom_outx,
custom_y,
custom_outy,
custom_out,
custom_x_grad,
custom_y_grad,
) = optional_inplace_vector_dynamic_add(
True,
device,
dtype,
self.np_x,
np_y,
)
(
pd_x,
pd_outx,
pd_y,
pd_outy,
pd_out,
pd_x_grad,
pd_y_grad,
) = optional_inplace_vector_dynamic_add(
False,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(pd_y, pd_outy, "inplace_pd_y")
self.check_output(custom_y, custom_outy, "inplace_custom_y")
self.check_output(custom_x, pd_x, "x")
self.check_output(custom_outx, pd_outx, "outx")
self.check_output(custom_y, pd_y, "y")
self.check_output(custom_outy, pd_outy, "outy")
self.check_output(custom_out, pd_out, "out")
self.check_output(custom_x_grad, pd_x_grad, "x_grad")
self.check_output(custom_y_grad, pd_y_grad, "y_grad")
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册