未验证 提交 6a10e604 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Optional] CustomOP supports optional vector<Tensor> input (#51973)

上级 5754aae5
......@@ -174,7 +174,7 @@ static void RunKernelFunc(
custom_t.set_impl(std::make_shared<phi::DenseTensor>(*x));
custom_vec_in.emplace_back(custom_t);
}
} else { // optional inputs, `custom_vec_in` is empty
} else { // optional inputs.
PADDLE_ENFORCE(
detail::IsOptionalVar(in_name),
phi::errors::NotFound("Your custom operator's KernelFunc cannot "
......@@ -182,6 +182,12 @@ static void RunKernelFunc(
in_name));
VLOG(3) << "Custom Operator: KernelFunc's vector input " << in_name
<< " is optional dtype with None input";
// NOTE(HongyuJia): In dygraph mode, we cannot distinguish Tensor from
// vector<Tensor> when the user passes None, so dygraph mode appends one
// uninitialized Tensor to CustomOpKernelContext. To stay compatible with
// dygraph mode, one uninitialized tensor is also emplaced into
// `custom_vec_in` here.
custom_vec_in.emplace_back(paddle::Tensor());
}
kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
} else { // inputs Tensor
......
......@@ -1060,7 +1060,9 @@ PYBIND11_MODULE(libpaddle, m) {
if (PyList_Check(obj) || PyTuple_Check(obj)) {
self.EmplaceBackInputs(
std::move(CastPyArg2VectorOfTensor(obj, 1)));
} else if (obj == Py_None) { // check optional Tensor
} else if (obj == Py_None) {
// Check optional Tensor, use one un-initialized tensor to
// indicate both Tensor and vector<Tensor> inputs
self.EmplaceBackInput(std::move(paddle::Tensor()));
} else {
self.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, 1)));
......
......@@ -241,6 +241,26 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// Specialization that supplies an optional vector<Tensor> kernel argument.
// The tensors of input slot `in_idx` are fetched from the context; when the
// slot is empty, or its first tensor is not initialized, `paddle::none` is
// forwarded in place of the vector. The remaining arguments are handled by
// the helper for `Tail...` with the input index advanced by one.
template <typename... Tail>
struct ComputeCallHelper<const paddle::optional<std::vector<paddle::Tensor>>&,
                         Tail...> {
  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
  static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
    auto& slot = ctx->InputRangeAt(in_idx);
    auto tensors = ctx->InputsBetween(slot.first, slot.second);
    // A leading un-initialized tensor marks a None input.
    const bool has_value = !tensors.empty() && tensors[0].is_initialized();
    if (has_value) {
      ComputeCallHelper<Tail...>::
          template Compute<in_idx + 1, attr_idx, out_idx>(
              ctx, pargs..., tensors);
    } else {
      ComputeCallHelper<Tail...>::
          template Compute<in_idx + 1, attr_idx, out_idx>(
              ctx, pargs..., paddle::none);
    }
  }
};
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
......@@ -486,6 +506,33 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// Specialization that supplies an optional vector-of-shapes argument to the
// user's InferShape function. The entry at `vec_in_idx` of
// `vec_input_shapes` is used; an empty entry is forwarded as `paddle::none`.
// The remaining arguments are handled by the helper for `Tail...` with the
// vector-input index advanced by one.
template <typename... Tail>
struct InferShapeCallHelper<
    const paddle::optional<std::vector<std::vector<int64_t>>>&,
    Tail...> {
  template <int in_idx,
            int vec_in_idx,
            int attr_idx,
            typename... PreviousArgs>
  static Return InferShape(
      const std::vector<std::vector<int64_t>>& input_shapes,
      const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
      const std::vector<paddle::any>& attrs,
      const PreviousArgs&... pargs) {
    const std::vector<std::vector<int64_t>>& shapes =
        vec_input_shapes[vec_in_idx];
    if (!shapes.empty()) {
      return InferShapeCallHelper<Tail...>::
          template InferShape<in_idx, vec_in_idx + 1, attr_idx>(
              input_shapes, vec_input_shapes, attrs, pargs..., shapes);
    }
    // No shapes recorded for this slot: treat the optional input as absent.
    return InferShapeCallHelper<Tail...>::
        template InferShape<in_idx, vec_in_idx + 1, attr_idx>(
            input_shapes, vec_input_shapes, attrs, pargs..., paddle::none);
  }
};
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(std::vector<int64_t>);
......@@ -593,8 +640,7 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
template <typename... Tail>
struct InferDtypeCallHelper<const paddle::optional<paddle::DataType>&,
Tail...> {
struct InferDtypeCallHelper<const paddle::optional<DataType>&, Tail...> {
template <int in_idx, int vec_in_idx, typename... PreviousArgs>
static Return InferDtype(
const std::vector<DataType>& input_dtypes,
......@@ -613,6 +659,27 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// Specialization that supplies an optional vector-of-dtypes argument to the
// user's InferDtype function. The entry at `vec_in_idx` of
// `vec_input_dtypes` is used; an empty entry is forwarded as `paddle::none`.
// The remaining arguments are handled by the helper for `Tail...` with the
// vector-input index advanced by one.
template <typename... Tail>
struct InferDtypeCallHelper<const paddle::optional<std::vector<DataType>>&,
                            Tail...> {
  template <int in_idx, int vec_in_idx, typename... PreviousArgs>
  static Return InferDtype(
      const std::vector<DataType>& input_dtypes,
      const std::vector<std::vector<DataType>>& vec_input_dtypes,
      const PreviousArgs&... pargs) {
    const std::vector<DataType>& dtypes = vec_input_dtypes[vec_in_idx];
    if (!dtypes.empty()) {
      return InferDtypeCallHelper<Tail...>::
          template InferDtype<in_idx, vec_in_idx + 1>(
              input_dtypes, vec_input_dtypes, pargs..., dtypes);
    }
    // No dtypes recorded for this slot: treat the optional input as absent.
    return InferDtypeCallHelper<Tail...>::
        template InferDtype<in_idx, vec_in_idx + 1>(
            input_dtypes, vec_input_dtypes, pargs..., paddle::none);
  }
};
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType);
......
......@@ -19,21 +19,19 @@
#include "paddle/extension.h"
// Accumulates `x_data` into `out_data` element-wise:
//   out_data[i] += x_data[i]  for i in [0, numel).
// Both buffers must hold at least `numel` elements. A non-positive `numel`
// leaves `out_data` untouched.
template <typename data_t>
void add_one_pointer(const data_t* x_data, data_t* out_data, int64_t numel) {
  // Index with the same signed type as `numel`: an unsigned index would turn
  // a negative count into a huge bound and run past the end of the buffers.
  for (int64_t i = 0; i < numel; ++i) {
    out_data[i] += x_data[i];
  }
}
// Writes the element-wise sum of two buffers into a third:
//   out_data[i] = x_data[i] + y_data[i]  for i in [0, numel).
// All three buffers must hold at least `numel` elements; `x_data` and
// `y_data` may point to the same buffer. A non-positive `numel` leaves
// `out_data` untouched.
template <typename data_t>
void add_two_pointers(const data_t* x_data,
                      const data_t* y_data,
                      data_t* out_data,
                      int64_t numel) {
  // Index with the same signed type as `numel`: an unsigned index would turn
  // a negative count into a huge bound and run past the end of the buffers.
  for (int64_t i = 0; i < numel; ++i) {
    out_data[i] = x_data[i] + y_data[i];
  }
}
......@@ -53,12 +51,12 @@ std::vector<paddle::Tensor> AddForward(
PD_DISPATCH_FLOATING_TYPES(
x.type(), "AddForward", ([&] {
if (y) {
add_forward_kernel<data_t>(x.data<data_t>(),
y->data<data_t>(),
out.data<data_t>(),
x.size());
add_two_pointers<data_t>(x.data<data_t>(),
y->data<data_t>(),
out.data<data_t>(),
x.size());
} else {
add_forward_kernel<data_t>(
add_two_pointers<data_t>(
x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
}
}));
......@@ -69,7 +67,6 @@ std::vector<paddle::DataType> AddInferDtype(
const paddle::DataType& x_dtype,
const paddle::optional<paddle::DataType>& y_dtype) {
if (y_dtype) {
std::cout << "DEBUG AddInferDtype" << *y_dtype << std::endl;
return {*y_dtype};
}
return {x_dtype};
......@@ -98,18 +95,14 @@ std::vector<paddle::Tensor> AddBackward(
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
paddle::Tensor y_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
out_grad.type(), "AddBackward", ([&] {
add_backward_kernel<data_t>(
x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
if (y) {
add_backward_kernel<data_t>(
y_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
} else {
add_backward_kernel<data_t>(
x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
add_one_pointer<data_t>(
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
if (!y) {
add_one_pointer<data_t>(
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
}
}));
......@@ -127,3 +120,91 @@ PD_BUILD_GRAD_OP(custom_add)
.Inputs({"X", paddle::Optional("Y"), paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(AddBackward));
/*
Forward kernel (CPU tensors only):
  y given  -> out = x + y[0] + y[1] + ...
  y absent -> out = x + x
*/
std::vector<paddle::Tensor> AddVectorForward(
    const paddle::Tensor& x,
    const paddle::optional<std::vector<paddle::Tensor>>& y) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  // The accumulator starts at zero with the shape, dtype and place of x.
  paddle::Tensor out = paddle::zeros(x.shape(), x.dtype(), x.place());

  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "AddVectorForward", ([&] {
        if (!y) {
          // No optional input: produce x + x in a single pass.
          add_two_pointers<data_t>(
              x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
        } else {
          // Fold x first, then every tensor of y, into the accumulator.
          add_one_pointer<data_t>(
              x.data<data_t>(), out.data<data_t>(), out.size());
          for (const auto& addend : *y) {
            add_one_pointer<data_t>(
                addend.data<data_t>(), out.data<data_t>(), out.size());
          }
        }
      }));
  return {out};
}
// Output dtype: the dtype of the first element of `y_dtype` when the optional
// vector input is given, otherwise the dtype of x.
std::vector<paddle::DataType> AddVectorInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::optional<std::vector<paddle::DataType>>& y_dtype) {
  if (!y_dtype) {
    return {x_dtype};
  }
  return {y_dtype->at(0)};
}
// Output shape: the shape of the first element of `y_shape` when the optional
// vector input is given, otherwise the shape of x.
std::vector<std::vector<int64_t>> AddVectorInferShape(
    const std::vector<int64_t>& x_shape,
    const paddle::optional<std::vector<std::vector<int64_t>>>& y_shape) {
  if (!y_shape) {
    return {x_shape};
  }
  return {y_shape->at(0)};
}
/*
Backward kernel (CPU tensors only):
  y given  -> x_grad = out_grad
  y absent -> x_grad = out_grad + out_grad
*/
std::vector<paddle::Tensor> AddVectorBackward(
    const paddle::Tensor& x,
    const paddle::optional<std::vector<paddle::Tensor>>& y,
    const paddle::Tensor& out_grad) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  // The gradient accumulator starts at zero with the shape/dtype/place of x.
  paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());

  PD_DISPATCH_FLOATING_TYPES(
      out_grad.type(), "AddVectorBackward", ([&] {
        // out_grad is accumulated once when y is given, and twice when y is
        // absent (the forward pass computed x + x in that case).
        const int passes = y ? 1 : 2;
        for (int pass = 0; pass < passes; ++pass) {
          add_one_pointer<data_t>(
              out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
        }
      }));
  return {x_grad};
}
// Registers the forward op `custom_add_vec`: input "X", optional vector
// input "Y", and output "Out", together with its kernel, shape-inference and
// dtype-inference functions.
PD_BUILD_OP(custom_add_vec)
.Inputs({"X", paddle::Optional(paddle::Vec("Y"))})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(AddVectorForward))
.SetInferShapeFn(PD_INFER_SHAPE(AddVectorInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AddVectorInferDtype));
// Registers the gradient op of `custom_add_vec`: it takes "X", the optional
// vector "Y" and the gradient of "Out", and outputs the gradient of "X".
PD_BUILD_GRAD_OP(custom_add_vec)
.Inputs({"X", paddle::Optional(paddle::Vec("Y")), paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(AddVectorBackward));
......@@ -105,12 +105,94 @@ def optional_static_add(phi_func, device, dtype, np_x, np_y):
return x_v, out_v, x_grad_v
def optional_vector_dynamic_add(phi_func, device, dtype, np_x, np_inputs):
    """Add x and an optional list of inputs, then backpropagate.

    When ``phi_func`` is true the custom op ``custom_add_vec`` is used;
    otherwise the same result is built from ``paddle.add`` calls. With
    ``np_inputs`` set to None the expected result is ``x + x``.

    Returns the numpy values of x, the output, and the gradient of x.
    """
    paddle.set_device(device)
    x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)

    # Convert the optional numpy inputs to tensors (None stays None).
    if np_inputs is None:
        vec_inputs = None
    else:
        vec_inputs = [
            paddle.to_tensor(np_item, dtype=dtype, stop_gradient=False)
            for np_item in np_inputs
        ]

    if phi_func:
        out = custom_optional.custom_add_vec(x, vec_inputs)
    elif vec_inputs is None:
        out = paddle.add(x, x)
    else:
        # Reference path: chain paddle.add over every element of the list.
        out = paddle.add(x, vec_inputs[0])
        for extra in vec_inputs[1:]:
            out = paddle.add(out, extra)

    out.backward()
    return x.numpy(), out.numpy(), x.grad.numpy()
def optional_vector_static_add(phi_func, device, dtype, np_x, np_inputs):
    """Static-graph variant: add x and an optional pair of inputs.

    Builds a fresh program in a fresh scope. When ``phi_func`` is true the
    custom op ``custom_add_vec`` is used; otherwise the same result is built
    from ``paddle.add`` calls. With ``np_inputs`` set to None the expected
    result is ``x + x``; otherwise its first two arrays are fed as "y1"/"y2".
    Backward is appended on the mean of the output.

    Returns the fetched values of x, the output, and the gradient of x.
    """
    paddle.enable_static()
    paddle.set_device(device)
    with static.scope_guard(static.Scope()), static.program_guard(
        static.Program()
    ):
        feature_dim = np_x.shape[1]
        x = static.data(name="x", shape=[None, feature_dim], dtype=dtype)
        x.stop_gradient = False
        feed_dict = {"x": np_x.astype(dtype)}

        # Declare the optional inputs (None stays None).
        ys = None
        if np_inputs is not None:
            y1 = static.data(name="y1", shape=[None, feature_dim], dtype=dtype)
            y1.stop_gradient = False
            y2 = static.data(name="y2", shape=[None, feature_dim], dtype=dtype)
            y2.stop_gradient = False
            feed_dict.update(
                {
                    "y1": np_inputs[0].astype(dtype),
                    "y2": np_inputs[1].astype(dtype),
                }
            )
            ys = [y1, y2]

        if phi_func:
            out = custom_optional.custom_add_vec(x, ys)
        elif ys is None:
            out = paddle.add(x, x)
        else:
            # Reference path: x + y1 + y2 via two paddle.add calls.
            out = paddle.add(x, ys[0])
            out = paddle.add(out, ys[1])

        mean_out = paddle.mean(out)
        static.append_backward(mean_out)

        exe = static.Executor()
        exe.run(static.default_startup_program())
        fetch_names = [x.name, out.name, x.name + "@GRAD"]
        x_v, out_v, x_grad_v = exe.run(
            static.default_main_program(),
            feed=feed_dict,
            fetch_list=fetch_names,
        )
    paddle.disable_static()
    return x_v, out_v, x_grad_v
class TestCustomOptionalJit(unittest.TestCase):
def setUp(self):
self.dtypes = ['float32', 'float64']
self.devices = ['cpu']
self.np_x = np.random.random((3, 2)).astype("float32")
self.np_y = np.random.random((3, 2)).astype("float32")
self.np_inputs = [
np.random.random((3, 2)).astype("float32"),
np.random.random((3, 2)).astype("float32"),
]
def check_output(self, out, pd_out, name):
np.testing.assert_array_equal(
......@@ -132,92 +214,97 @@ class TestCustomOptionalJit(unittest.TestCase):
),
)
def test_static_add(self):
def test_optional_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_static_add(
False,
device,
dtype,
self.np_x,
self.np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
True,
device,
dtype,
self.np_x,
self.np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
for np_y in [None, self.np_y]:
(pd_x, pd_out, pd_x_grad,) = optional_static_add(
False,
device,
dtype,
self.np_x,
np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
True,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_dynamic_add(self):
def test_optional_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_dynamic_add(
False,
device,
dtype,
self.np_x,
self.np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
True,
device,
dtype,
self.np_x,
self.np_y,
)
for np_y in [None, self.np_y]:
(pd_x, pd_out, pd_x_grad,) = optional_dynamic_add(
False,
device,
dtype,
self.np_x,
np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
True,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_optional_static_add(self):
def test_optional_vector_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_static_add(
False,
device,
dtype,
self.np_x,
None,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
True,
device,
dtype,
self.np_x,
None,
)
for np_y in [None, self.np_inputs]:
(phi_x, phi_out, phi_x_grad,) = optional_vector_static_add(
True,
device,
dtype,
self.np_x,
np_y,
)
(pd_x, pd_out, pd_x_grad,) = optional_vector_static_add(
False,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_optional_dynamic_add(self):
def test_optional_vector_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_dynamic_add(
False,
device,
dtype,
self.np_x,
None,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
True,
device,
dtype,
self.np_x,
None,
)
for np_y in [None, self.np_inputs]:
(phi_x, phi_out, phi_x_grad,) = optional_vector_dynamic_add(
True,
device,
dtype,
self.np_x,
np_y,
)
(pd_x, pd_out, pd_x_grad,) = optional_vector_dynamic_add(
False,
device,
dtype,
self.np_x,
np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册