Commit b6f8a799 authored by leaves-zwx, committed by GitHub

slice_update op (#3544)

* op & kernel

* api & test

* test grad

* format

* modify by comment

* rm unsupport sbp
Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
Former-commit-id: 8d02bbcb
Parent 4b8942ba
@@ -295,25 +295,8 @@ def slice(
return slice_v2(x, slice_tup_list, name=name)
@oneflow_export("slice_v2")
def slice_v2(
x: remote_blob_util.BlobDef,
slice_tup_list: Sequence[Tuple[int, int, int]],
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
r"""Extracts a slice from a tensor.
Args:
x: A `Blob`.
slice_tup_list: A list of slice tuples, each indicating the slice (start, stop, step) for one dimension.
name: A name for the operation (optional).
"""
name = name or id_util.UniqueStr("Slice_")
if not isinstance(name, str):
raise ValueError("name must be a string")
ndim = len(x.shape)
def _check_slice_tup_list(slice_tup_list, shape):
ndim = len(shape)
if not isinstance(slice_tup_list, (list, tuple)) or len(slice_tup_list) > ndim:
raise ValueError(
"slice_tup_list must be a list or tuple with length "
@@ -330,7 +313,7 @@ def slice_v2(
stop_list = []
step_list = []
for slice_tup, dim_size in zip(slice_tup_list, x.shape):
for slice_tup, dim_size in zip(slice_tup_list, shape):
if not isinstance(slice_tup, (tuple, list)) or len(slice_tup) != 3:
raise ValueError(
"element of slice_tup_list must be a list or tuple with form (start, stop, step)"
@@ -360,14 +343,75 @@ def slice_v2(
stop_list.append(stop)
step_list.append(step)
return start_list, stop_list, step_list
@oneflow_export("slice_v2")
def slice_v2(
x: remote_blob_util.BlobDef,
slice_tup_list: Sequence[Tuple[int, int, int]],
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
r"""Extracts a slice from a tensor.
Args:
x: A `Blob`.
slice_tup_list: A list of slice tuples, each indicating the slice (start, stop, step) for one dimension.
name: A name for the operation (optional).
"""
name = name or id_util.UniqueStr("Slice_")
if not isinstance(name, str):
raise ValueError("name must be a string")
start, stop, step = _check_slice_tup_list(slice_tup_list, x.shape)
op = (
flow.user_op_builder(name)
.Op("slice")
.Input("x", [x])
.Output("y")
.Attr("start", start_list)
.Attr("stop", stop_list)
.Attr("step", step_list)
.Attr("start", start)
.Attr("stop", stop)
.Attr("step", step)
.Build()
)
return op.InferAndTryRun().SoleOutputBlob()
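For reference, a minimal usage sketch of slice_v2 in the lazy-mode API this diff targets (the job name, shapes, and default function config are illustrative, not part of the change):

import numpy as np
import oneflow as flow
import oneflow.typing as otp

@flow.global_function(type="predict")
def slice_v2_demo(x: otp.Numpy.Placeholder((3, 7))) -> otp.Numpy:
    # keep all rows; take columns 1 and 3 (start=1, stop=5, step=2) -> shape (3, 2)
    return flow.slice_v2(x, [(None, None, None), (1, 5, 2)])

y = slice_v2_demo(np.random.rand(3, 7).astype(np.float32))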
@oneflow_export("slice_update")
def api_slice_update(
x: remote_blob_util.BlobDef,
update: remote_blob_util.BlobDef,
slice_tup_list: Sequence[Tuple[int, int, int]],
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
r"""Updates a slice of tensor `x`.
Args:
x: A `Blob` whose slice will be updated.
update: A `Blob` providing the values to write into the slice.
slice_tup_list: A list of slice tuples, each indicating the slice (start, stop, step) for one dimension.
name: A name for the operation (optional).
"""
if name is None:
name = id_util.UniqueStr("SliceUpdate_")
if not isinstance(name, str):
raise ValueError("name must be a string")
start, stop, step = _check_slice_tup_list(slice_tup_list, x.shape)
op = (
flow.user_op_builder(name)
.Op("slice_update")
.Input("x", [x])
.Input("update", [update])
.Output("y")
.Attr("start", start)
.Attr("stop", stop)
.Attr("step", step)
.Build()
)
return op.InferAndTryRun().SoleOutputBlob()
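A minimal usage sketch of the new flow.slice_update API (job name and shapes are illustrative; it mirrors the NumPy assignment y = x.copy(); y[5:, :-1, ::2] = update exercised by the tests below):

import numpy as np
import oneflow as flow
import oneflow.typing as otp

@flow.global_function(type="predict")
def slice_update_demo(
    x: otp.Numpy.Placeholder((10, 5, 4)),
    update: otp.Numpy.Placeholder((5, 4, 2)),
) -> otp.Numpy:
    # write `update` into x[5:, :-1, ::2] and return the updated blob
    return flow.slice_update(
        x, update, [(5, None, None), (None, -1, None), (None, None, 2)]
    )

y = slice_update_demo(
    np.random.rand(10, 5, 4).astype(np.float32),
    np.random.rand(5, 4, 2).astype(np.float32),
)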
......
@@ -99,6 +99,59 @@ def _make_slice_with_grad_func(
return slice_with_grad_job
def _make_slice_update_func(
slice_tup_list, input_shape, update_shape, dtype=flow.float32, func_cfg=None
):
@flow.global_function(type="predict", function_config=func_cfg)
def slice_update_job(
x: otp.Numpy.Placeholder(shape=input_shape, dtype=dtype),
update: otp.Numpy.Placeholder(shape=update_shape, dtype=dtype),
) -> otp.Numpy:
return flow.slice_update(x, update, slice_tup_list)
return slice_update_job
def _make_slice_update_grad_func(
slice_tup_list,
input_shape,
update_shape,
diff_watcher_maker=None,
dtype=flow.float32,
func_cfg=None,
):
@flow.global_function(type="train", function_config=func_cfg)
def slice_update_train_job(
x: otp.Numpy.Placeholder(shape=input_shape, dtype=dtype),
update: otp.Numpy.Placeholder(shape=update_shape, dtype=dtype),
) -> otp.Numpy:
x_var = flow.get_variable(
shape=input_shape,
dtype=dtype,
initializer=flow.constant_initializer(0.0),
name="x",
)
update_var = flow.get_variable(
shape=update_shape,
dtype=dtype,
initializer=flow.constant_initializer(0.0),
name="update",
)
x = x + x_var
update = update + update_var
if callable(diff_watcher_maker):
flow.watch_diff(x, diff_watcher_maker(input_shape))
flow.watch_diff(update, diff_watcher_maker(update_shape))
y = flow.slice_update(x, update, slice_tup_list)
flow.optimizer.SGD(
flow.optimizer.PiecewiseConstantScheduler([], [1e-3]), momentum=0
).minimize(y)
return y
return slice_update_train_job
def _test_slice(
test_case,
input,
@@ -201,33 +254,84 @@ def _test_slice_with_grad(
test_case.assertTrue(np.array_equal(output, of_output))
# This test case will raise a fatal error; the error information looks like the following:
# F0808 00:20:19.768465 23960 user_kernel.cpp:451] Check failed: shape_view.elem_cnt() <= static_shape.elem_cnt() (12 vs. 9)
# InferShape of OpKernel (op_type_name: slice, op_name: SliceDynamic_0) raise error,
# output arg's (name: y, index: 0) runtime shape (2,6) surpass the limit of static shape (3,3)
# *** Check failure stack trace: ***
# ...
# The reason is the mismatch between the static slice (used for memory allocation) and the dynamic slice (the actual slice).
# The result shape of slice [:, 3:-1] for static shape (3, 7) is (3, 3),
# which indicates that the blob has a memory limit of prod(3, 3) elements,
# and the result shape of slice [:, 3:-1] for dynamic shape (2, 10) is (2, 6),
# which causes the blob to exceed that memory limit.
# def test_slice_dynamic_dismatch(test_case):
# input = np.random.rand(2, 10)
# slice_args = [[(None, None, None), (3, -1, None)]]
# outputs = [input[:, 3:-1]]
# _test_slice_dynamic(test_case, input, slice_args, outputs, static_shape=(3, 7))
# static shape after slice is (5, 4)
# dynamic shape after slice is (4, 5)
# static shape after slice is (5, 3)
# dynamic shape after slice is (4, 4)
# def test_slice_dynamic_anomaly_failed(test_case):
# input = np.random.rand(4, 7)
# slice_args = [[(None, None, None), (3, None, None)]]
# outputs = [input[:, 3:]]
# _test_slice_dynamic(test_case, input, slice_args, outputs, static_shape=(5, 6))
def _test_slice_update(
test_case,
input,
update,
slice_args,
output,
dtype=flow.float32,
device_tag=DEFAULT_DEVICE_TAG,
verbose=False,
):
input = input.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
update = update.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
output = output.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
flow.clear_default_session()
func_cfg = flow.FunctionConfig()
func_cfg.default_data_type(dtype)
func_cfg.default_placement_scope(flow.scope.placement(device_tag, "0:0"))
slice_func = _make_slice_update_func(
slice_args, input.shape, update.shape, dtype, func_cfg
)
of_output = slice_func(input, update)
if verbose:
print("input:\n{}".format(input))
print("update:\n{}".format(update))
print("slice_args:", slice_args)
print("output:\n{}".format(output))
print("dtype:", dtype)
print("device_tag:", device_tag)
print("of_output:\n{}".format(of_output))
test_case.assertTrue(np.array_equal(output, of_output))
def _test_slice_update_grad(
test_case,
input,
update,
slice_args,
output,
input_diff,
update_diff,
dtype=flow.float32,
device_tag=DEFAULT_DEVICE_TAG,
verbose=False,
):
input = input.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
update = update.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
output = output.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
input_diff = input_diff.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
update_diff = update_diff.astype(flow.convert_oneflow_dtype_to_numpy_dtype(dtype))
if verbose:
print("dtype: {}".format(dtype))
print("device_tag: {}".format(device_tag))
print("input: {}\n{}\n".format(input.shape, input))
print("output: {}\n{}\n".format(output.shape, output))
def _make_diff_watcher(shape):
def _watch_diff(diff: otp.Numpy):
if shape == input_diff.shape:
test_case.assertTrue(np.array_equal(diff, input_diff))
elif shape == update_diff.shape:
test_case.assertTrue(np.array_equal(diff, update_diff))
return _watch_diff
flow.clear_default_session()
func_cfg = flow.FunctionConfig()
func_cfg.default_data_type(dtype)
func_cfg.default_placement_scope(flow.scope.placement(device_tag, "0:0"))
slice_func = _make_slice_update_grad_func(
slice_args, input.shape, update.shape, _make_diff_watcher, dtype, func_cfg
)
ret = slice_func(input, update)
test_case.assertTrue(np.array_equal(ret, output))
@flow.unittest.skip_unless_1n1d()
@@ -396,6 +500,36 @@ class TestSliceV2(flow.unittest.TestCase):
outputs = [input[:, 2:]]
_test_slice_dynamic(test_case, input, slice_args, outputs, static_shape=(5, 6))
"""This test case will raise a fatal error; the error information looks like the following:
F0808 00:20:19.768465 23960 user_kernel.cpp:451] Check failed: shape_view.elem_cnt() <= static_shape.elem_cnt() (12 vs. 9)
InferShape of OpKernel (op_type_name: slice, op_name: SliceDynamic_0) raise error,
output arg's (name: y, index: 0) runtime shape (2,6) surpass the limit of static shape (3,3)
*** Check failure stack trace: ***
...
The reason is the mismatch between the static slice (used for memory allocation) and the dynamic slice (the actual slice).
The result shape of slice [:, 3:-1] for static shape (3, 7) is (3, 3),
which indicates that the blob has a memory limit of prod(3, 3) elements,
and the result shape of slice [:, 3:-1] for dynamic shape (2, 10) is (2, 6),
which causes the blob to exceed that memory limit.
"""
# def test_slice_dynamic_dismatch(test_case):
# input = np.random.rand(2, 10)
# slice_args = [[(None, None, None), (3, -1, None)]]
# outputs = [input[:, 3:-1]]
# _test_slice_dynamic(test_case, input, slice_args, outputs, static_shape=(3, 7))
"""
For the passing case above (slice [:, 2:] with static_shape=(5, 6)): the static shape after the slice is (5, 4) and the dynamic shape after the slice is (4, 5), so 20 <= 20 elements and the static limit is respected.
For the commented-out test below (slice [:, 3:] with static_shape=(5, 6)): the static shape after the slice is (5, 3) and the dynamic shape after the slice is (4, 4), so 16 > 15 elements and the static limit would be exceeded.
"""
# def test_slice_dynamic_anomaly_failed(test_case):
# input = np.random.rand(4, 7)
# slice_args = [[(None, None, None), (3, None, None)]]
# outputs = [input[:, 3:]]
# _test_slice_dynamic(test_case, input, slice_args, outputs, static_shape=(5, 6))
def test_slice_with_grad(test_case):
input = np.random.rand(2, 5, 4)
slice_tup_list = [(None, None, None), (2, -2, None)]
@@ -412,6 +546,50 @@ class TestSliceV2(flow.unittest.TestCase):
test_case, input, slice_tup_list, output, diff, **kwarg
)
def test_slice_update(test_case):
input = np.random.rand(10, 5, 4)
update = input[5:, :-1, ::2]
update = np.random.rand(*update.shape)
output = np.copy(input)
output[5:, :-1, ::2] = update
slice_tup_list = [(5, None, None), (None, -1, None), (None, None, 2)]
arg_dict = collections.OrderedDict()
arg_dict["dtype"] = [flow.float32, flow.float64]
arg_dict["device_tag"] = ["cpu", "gpu"]
arg_dict["verbose"] = [False]
for kwarg in test_util.GenArgDict(arg_dict):
_test_slice_update(
test_case, input, update, slice_tup_list, output, **kwarg
)
def test_slice_update_grad(test_case):
input = np.random.rand(2, 7)
update = input[:, 1:4]
update = np.random.rand(*update.shape)
update_diff = np.ones(update.shape)
input_diff = np.ones(input.shape)
input_diff[:, 1:4] = 0
output = np.copy(input)
output[:, 1:4] = update
slice_tup_list = [(None, None, None), (1, 4, None)]
arg_dict = collections.OrderedDict()
arg_dict["dtype"] = [flow.float32, flow.float64]
arg_dict["device_tag"] = ["cpu", "gpu"]
arg_dict["verbose"] = [False]
for kwarg in test_util.GenArgDict(arg_dict):
_test_slice_update_grad(
test_case,
input,
update,
slice_tup_list,
output,
input_diff,
update_diff,
**kwarg
)
if __name__ == "__main__":
unittest.main()
@@ -94,14 +94,44 @@ class SliceGradKernel final : public user_op::OpKernel {
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_SLICE_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("slice").SetCreateFn<SliceKernel<device, dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("slice_grad") \
.SetCreateFn<SliceGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
template<DeviceType device_type, typename T>
class SliceUpdateKernel final : public user_op::OpKernel {
public:
SliceUpdateKernel() = default;
~SliceUpdateKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* update_tensor = ctx->Tensor4ArgNameAndIndex("update", 0);
user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
Memcpy<device_type>(ctx->device_ctx(), y_tensor->mut_dptr<T>(), x_tensor->dptr<T>(),
y_tensor->shape().elem_cnt() * sizeof(T));
SliceParams params = ConstructSliceParams(ctx, y_tensor, update_tensor);
SliceKernelUtil<device_type, T>::Backward(ctx->device_ctx(), params, update_tensor->dptr<T>(),
y_tensor->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
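The kernel above first copies x into y and then reuses SliceKernelUtil::Backward as a scatter that writes update into the sliced region of y. A NumPy sketch of the same semantics, for reference only (not the actual implementation):

import numpy as np

def slice_update_ref(x, update, start, stop, step):
    y = x.copy()  # corresponds to the Memcpy of x into y
    # build one Python slice per dimension from the regulated start/stop/step attributes
    idx = tuple(slice(b, e, s) for b, e, s in zip(start, stop, step))
    y[idx] = update  # corresponds to scattering update into the sliced region of y
    return y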
#define REGISTER_SLICE_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("slice").SetCreateFn<SliceKernel<device, dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("slice_grad") \
.SetCreateFn<SliceGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("slice_update") \
.SetCreateFn<SliceUpdateKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("update", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn([](const user_op::InferContext&, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); \
return Maybe<void>::Ok(); \
});
#define REGISTER_SLICE_KERNELS_WITH_DEVICE(device) \
REGISTER_SLICE_KERNELS(device, float) \
......
@@ -156,6 +156,73 @@ void InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModi
like_modifier->set_requires_grad(false);
}
Maybe<void> InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) {
const auto* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const int64_t ndim = x_desc->shape().NumAxes();
const auto* update_desc = ctx->TensorDesc4ArgNameAndIndex("update", 0);
CHECK_EQ_OR_RETURN(update_desc->shape().NumAxes(), ndim);
CHECK_EQ_OR_RETURN(update_desc->data_type(), x_desc->data_type());
const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
CHECK_EQ_OR_RETURN(start_vec.size(), ndim);
CHECK_EQ_OR_RETURN(stop_vec.size(), ndim);
CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
// validate update shape and start, stop, step attributes
FOR_RANGE(int, i, 0, ndim) {
const int64_t dim_size = x_desc->shape().At(i);
const int64_t step = step_vec.at(i);
CHECK_NE_OR_RETURN(step, 0) << "slice step cannot be 0";
int64_t start = RegulateSliceStart(start_vec.at(i), dim_size);
int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size);
if (step > 0) {
CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0"
", otherwise empty result will be outputted.";
} else {
CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0"
", otherwise empty result will be outputted.";
}
const int64_t diff = (step > 0) ? (stop - start - 1) : (stop - start + 1);
const int64_t sliced_dim_size = diff / step + 1;
CHECK_EQ_OR_RETURN(sliced_dim_size, update_desc->shape().At(i))
<< "sliced dim size " << sliced_dim_size << " at axis " << i
<< " not equal to the update shape " << update_desc->shape().ToString();
}
// the split axis can't be sliced
const SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0);
if (ctx->parallel_ctx().parallel_num() != 1 && x_sbp.has_split_parallel()) {
const int64_t split_axis = x_sbp.split_parallel().axis();
CHECK_GE_OR_RETURN(split_axis, 0);
CHECK_LT_OR_RETURN(split_axis, ndim);
CHECK_OR_RETURN(IsFullSlice(start_vec.at(split_axis), stop_vec.at(split_axis),
step_vec.at(split_axis), x_desc->shape().At(split_axis)));
}
auto* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0);
*y_desc = *x_desc;
return Maybe<void>::Ok();
}
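The sliced_dim_size computation above follows Python slicing semantics, and within the validated ranges (start < stop for step > 0, start > stop for step < 0) C++ truncating division and Python floor division agree. A small self-check with an illustrative helper:

def sliced_dim_size(start, stop, step):
    # mirrors: diff = stop - start - 1 (step > 0) or stop - start + 1 (step < 0); size = diff / step + 1
    diff = stop - start - 1 if step > 0 else stop - start + 1
    return diff // step + 1

assert sliced_dim_size(0, 10, 3) == len(range(10)[0:10:3])  # 4 elements: 0, 3, 6, 9
assert sliced_dim_size(8, 1, -2) == len(range(10)[8:1:-2])   # 4 elements: 8, 6, 4, 2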
Maybe<void> GetSliceUpdateOpSbpSignature(user_op::SbpContext* ctx) {
const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
const int64_t ndim = x_shape.NumAxes();
const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
CHECK_EQ_OR_RETURN(start_vec.size(), ndim);
CHECK_EQ_OR_RETURN(stop_vec.size(), ndim);
CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
FOR_RANGE(int, i, 0, ndim) {
if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) {
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
}
}
ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
return Maybe<void>::Ok();
}
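The SBP signature above allows splitting only on axes that are not actually sliced, plus PartialSum for all arguments. A hypothetical Python mirror of the full-slice check assumed here (the real IsFullSlice helper is defined earlier in this file and is not shown in the diff):

def is_full_slice(start, stop, step, dim_size):
    # a full slice keeps every element of the axis: start 0, stop dim_size, step 1
    return step == 1 and start == 0 and stop == dim_size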
void GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
@@ -172,6 +239,44 @@ void GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
}
}
void GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
const std::string update_grad_op_name = ctx->FwOp().op_name() + "_update_grad";
ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
return builder.OpTypeName("slice")
.InputBind("x", ctx->FwOp().output_grad("y", 0))
.Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
.Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
.Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
.Output("y")
.Build();
});
ctx->FwOp().InputGradBind(user_op::OpArg("update", 0), [&]() -> const std::string& {
return ctx->GetOp(update_grad_op_name).output("y", 0);
});
const std::string zero_grad_op_name = ctx->FwOp().op_name() + "_zero_grad";
ctx->DefineOp(zero_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
return builder.OpTypeName("zero_like")
.InputBind("like", ctx->FwOp().input("update", 0))
.Output("out")
.Build();
});
const std::string x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
ctx->DefineOp(x_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
return builder.OpTypeName("slice_update")
.InputBind("x", ctx->FwOp().output_grad("y", 0))
.InputBind("update", ctx->GetOp(zero_grad_op_name).output("out", 0))
.Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
.Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
.Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
.Output("y")
.Build();
});
ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&]() -> const std::string& {
return ctx->GetOp(x_grad_op_name).output("y", 0);
});
}
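The backward config above produces the gradient of update as a slice of dy, and the gradient of x as dy with the sliced region zeroed (slice_update of dy with a zero_like blob). A NumPy sketch of that rule, for reference only:

import numpy as np

def slice_update_grad_ref(dy, update_shape, start, stop, step):
    idx = tuple(slice(b, e, s) for b, e, s in zip(start, stop, step))
    d_update = dy[idx]                # slice(dy): gradient w.r.t. update
    dx = dy.copy()
    dx[idx] = np.zeros(update_shape)  # slice_update(dy, zeros): gradient w.r.t. x
    return dx, d_update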
} // namespace
REGISTER_USER_OP("slice")
@@ -196,4 +301,16 @@ REGISTER_USER_OP("slice_grad")
REGISTER_USER_OP_GRAD("slice").SetGenBackwardOpConfFn(GenSliceGradOp);
REGISTER_USER_OP("slice_update")
.Input("x")
.Input("update")
.Output("y")
.Attr("start", UserOpAttrType::kAtListInt64)
.Attr("stop", UserOpAttrType::kAtListInt64)
.Attr("step", UserOpAttrType::kAtListInt64)
.SetTensorDescInferFn(InferSliceUpdateOpTensorDesc)
.SetGetSbpFn(GetSliceUpdateOpSbpSignature);
REGISTER_USER_OP_GRAD("slice_update").SetBackwardOpConfGenFn(GenSliceUpdateGradOp);
} // namespace oneflow