Unverified commit d0390839, authored by Liang Depeng, committed by GitHub

add flatten op implementation (#3789)

* add flatten op

* make changes according to reviews
Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
Parent 1fe7b296
@@ -156,6 +156,59 @@ def gather(
)
@oneflow_export("flatten")
def flatten(
input: remote_blob_util.BlobDef,
start_dim: int = 0,
end_dim: int = -1,
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
r"""Flattens a contiguous range of dims in a Blob.
Args:
input: A `Blob`.
start_dim: The first dim to flatten.
end_dim: The last dim to flatten.
name: A name for the operation (optional).
Returns:
A `Blob`, has the same type as `input`.
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
import oneflow.typing as tp
@flow.global_function()
def flatten_Job(input: tp.Numpy.Placeholder(shape=(4, 4, 3, 2), dtype=flow.float32)
) -> tp.Numpy:
flatten_blob = flow.flatten(input, start_dim=1, end_dim=-1)
return flatten_blob
input = np.zeros((4, 4, 3, 2)).astype(np.float32)
out = flatten_Job(input)
# out.shape (4, 24)
"""
if name is None:
name = id_util.UniqueStr("Flatten_")
return (
flow.user_op_builder(name)
.Op("flatten")
.Input("in", [input])
.Output("out")
.Attr("start_dim", start_dim)
.Attr("end_dim", end_dim)
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
)
def infer_shape(x, shape):
    dim_index_need_infer = shape.index(-1) if shape.count(-1) == 1 else None
    in_elem_cnt = reduce(operator.mul, x.shape, 1)
......
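For reference, the output-shape rule implemented above is exactly a NumPy reshape over the collapsed axis range. A minimal Python sketch (the helper `flatten_shape` is hypothetical, not part of the OneFlow API):

import numpy as np

# Hypothetical helper mirroring flatten's output-shape rule.
def flatten_shape(shape, start_dim=0, end_dim=-1):
    # Normalize a negative end_dim, collapse dims [start_dim, end_dim] into
    # a single axis, and keep all other dims unchanged.
    end_dim = end_dim + len(shape) if end_dim < 0 else end_dim
    middle = int(np.prod(shape[start_dim : end_dim + 1]))
    return shape[:start_dim] + (middle,) + shape[end_dim + 1 :]

x = np.zeros((4, 4, 3, 2), dtype=np.float32)
assert flatten_shape(x.shape, start_dim=1, end_dim=-1) == (4, 24)
assert x.reshape(flatten_shape(x.shape, 1, -1)).shape == (4, 24)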
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
from test_util import GenArgList
import test_global_storage
def compare_with_numpy(test_case, device_type, input_shape, start_end_dim):
    assert device_type in ["gpu", "cpu"]
    flow.clear_default_session()
    func_config = flow.FunctionConfig()
    func_config.default_data_type(flow.float)

    start_dim = start_end_dim[0]
    end_dim = start_end_dim[1]

    @flow.global_function(type="train", function_config=func_config)
    def FlattenJob() -> flow.typing.Numpy:
        with flow.scope.placement(device_type, "0:0"):
            x = flow.get_variable(
                "in",
                shape=input_shape,
                dtype=flow.float,
                initializer=flow.random_uniform_initializer(minval=2, maxval=5),
                trainable=True,
            )
            loss = flow.flatten(x, start_dim=start_dim, end_dim=end_dim)
            flow.optimizer.SGD(
                flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
            ).minimize(loss)

            flow.watch(x, test_global_storage.Setter("x"))
            flow.watch_diff(x, test_global_storage.Setter("x_diff"))

            return loss

    # OneFlow
    check_point = flow.train.CheckPoint()
    check_point.init()
    of_out = FlattenJob()
    # Numpy
    of_x = test_global_storage.Get("x")
    of_x_shape = of_x.shape
    of_x_diff = test_global_storage.Get("x_diff")

    true_end_dim = end_dim + len(of_x_shape) if end_dim < 0 else end_dim
    new_shape = []
    for i in range(0, start_dim):
        new_shape.append(of_x_shape[i])
    flatten_dim = 1
    for i in range(start_dim, true_end_dim + 1):
        flatten_dim *= of_x_shape[i]
    new_shape.append(flatten_dim)
    for i in range(true_end_dim + 1, len(of_x_shape)):
        new_shape.append(of_x_shape[i])

    np_out = np.reshape(of_x, tuple(new_shape))

    test_case.assertTrue(of_out.shape == np_out.shape)
    test_case.assertTrue(np.allclose(of_out, np_out, rtol=1e-5, atol=1e-5))
    test_case.assertTrue(
        np.allclose(of_x_diff, np.ones(of_x_diff.shape), rtol=1e-5, atol=1e-5)
    )


@flow.unittest.skip_unless_1n1d()
class TestFlatten(flow.unittest.TestCase):
    def test_flatten(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_case"] = [test_case]
        arg_dict["device_type"] = ["gpu", "cpu"]
        arg_dict["input_shape"] = [(2, 3, 4, 5)]
        arg_dict["start_end_dim"] = [(0, -1), (1, 3), (2, -2)]
        for arg in GenArgList(arg_dict):
            compare_with_numpy(*arg)


if __name__ == "__main__":
    unittest.main()
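A note on the gradient expectation checked above: flatten is a pure reshape, so its backward (registered further below as reshape_like) only reshapes the output gradient back to the input shape. Because the flattened blob is minimized directly as the loss, every element of x_diff should be exactly 1. A NumPy sketch of that reasoning:

import numpy as np

x_shape = (2, 3, 4, 5)
grad_out = np.ones((2, 60), dtype=np.float32)  # d(loss)/d(flatten(x)) for start_dim=1
grad_in = grad_out.reshape(x_shape)            # what the reshape_like backward produces
assert np.array_equal(grad_in, np.ones(x_shape, dtype=np.float32))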
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/copy_data_content_kernel.h"
namespace oneflow {
#define REGISTER_FLATTEN_KERNEL(device)                                                         \
  REGISTER_USER_KERNEL("flatten")                                                               \
      .SetCreateFn<CopyDataContentKernel<device>>()                                             \
      .SetIsMatchedHob(user_op::HobDeviceTag() == device)                                       \
      .SetInplaceProposalFn([](const user_op::InferContext&,                                    \
                               user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
        OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));                      \
        return Maybe<void>::Ok();                                                               \
      });
REGISTER_FLATTEN_KERNEL(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_FLATTEN_KERNEL(DeviceType::kGPU)
#endif
} // namespace oneflow
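The registration above reuses CopyDataContentKernel and proposes aliasing ("out", 0) onto ("in", 0): flatten never rearranges elements, it only reinterprets the shape, so a plain copy (or an inplace no-op) is a correct kernel. The same property in NumPy terms, as a sketch:

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
y = x.reshape(2, 12)  # the equivalent of flatten(x, start_dim=1, end_dim=2)
# A contiguous reshape is a view over the same buffer: no element moves,
# which is why a data-copy/inplace kernel suffices for this op.
assert np.shares_memory(x, y)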
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
namespace oneflow {
namespace {
Maybe<void> GetSbpFn(user_op::SbpContext* ctx) {
  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape();
  const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
  const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
  CHECK_GE_OR_RETURN(start_dim, 0);
  CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes());
  const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim;
  CHECK_GE_OR_RETURN(true_end_dim, 0);
  CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes());
  CHECK_LE_OR_RETURN(start_dim, true_end_dim);
  for (int i = 0; i <= start_dim; ++i) {
    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
  }
  const int32_t diff = true_end_dim - start_dim;
  for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) {
    ctx->NewBuilder()
        .Split(user_op::OpArg("in", 0), i)
        .Split(user_op::OpArg("out", 0), i - diff)
        .Build();
  }
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}

Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) {
  const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
  const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
  const user_op::TensorDesc* in_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("in", 0);
  user_op::TensorDesc* out_tensor_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0);
  const Shape& in_shape = in_tensor_desc->shape();
  CHECK_GE_OR_RETURN(start_dim, 0);
  CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes());
  const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim;
  CHECK_GE_OR_RETURN(true_end_dim, 0);
  CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes());
  CHECK_LE_OR_RETURN(start_dim, true_end_dim);
  *out_tensor_desc = *in_tensor_desc;
  Shape* out_shape = out_tensor_desc->mut_shape();
  DimVector dim_vec;
  for (int i = 0; i < start_dim; ++i) { dim_vec.push_back(in_shape.At(i)); }
  int64_t flatten_dim = 1;
  for (int i = start_dim; i <= true_end_dim; ++i) { flatten_dim *= in_shape.At(i); }
  dim_vec.push_back(flatten_dim);
  for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) { dim_vec.push_back(in_shape.At(i)); }
  *out_shape = Shape(dim_vec);
  CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt());
  return Maybe<void>::Ok();
}

Maybe<void> GetBatchAxisInferFn(user_op::BatchAxisContext* ctx) {
  const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
  const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape();
  CHECK_GE_OR_RETURN(start_dim, 0);
  CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes());
  const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim;
  CHECK_GE_OR_RETURN(true_end_dim, 0);
  CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes());
  CHECK_LE_OR_RETURN(start_dim, true_end_dim);
  const int64_t input_batch_axis = (*ctx->BatchAxis4ArgNameAndIndex("in", 0)).value();
  OptInt64 output_batch_axis;
  if (input_batch_axis < start_dim) {
    output_batch_axis.set_value(input_batch_axis);
  } else if (input_batch_axis >= start_dim && input_batch_axis <= true_end_dim) {
    output_batch_axis.set_value(start_dim);
  } else if (input_batch_axis > true_end_dim) {
    output_batch_axis.set_value(input_batch_axis - (true_end_dim - start_dim));
  }
  *ctx->BatchAxis4ArgNameAndIndex("out", 0) = output_batch_axis;
  return Maybe<void>::Ok();
}

REGISTER_USER_OP("flatten")
    .Input("in")
    .Output("out")
    .Attr<int32_t>("start_dim", 0)
    .Attr<int32_t>("end_dim", -1)
    .SetTensorDescInferFn(TensorDescInferFn)
    .SetGetSbpFn(GetSbpFn)
    .SetBatchAxisInferFn(GetBatchAxisInferFn);

REGISTER_USER_OP_GRAD("flatten").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                           user_op::AddOpFn AddOp) {
  if (op.NeedGenGradTensor4OpInput("in", 0)) {
    user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
    user_op::UserOpConfWrapper reshape_grad_op =
        builder.Op("reshape_like")
            .Input("in", op.GetGradTensorWithOpOutput("out", 0))
            .Input("like", op.input("in", 0))
            .Output("out")
            .Build();
    op.BindGradTensorWithOpInput(reshape_grad_op.output("out", 0), "in", 0);
    AddOp(reshape_grad_op);
  }
});
} // namespace
} // namespace oneflow
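To make GetSbpFn's signatures concrete, here is a small Python sketch (the helper `flatten_split_pairs` is hypothetical, for illustration only) that reproduces the (in_axis, out_axis) split pairs registered above for one case; the op additionally allows the PartialSum-to-PartialSum signature:

def flatten_split_pairs(num_axes, start_dim, end_dim):
    # Mirrors GetSbpFn: axes up to and including start_dim keep their index
    # (splitting the input at start_dim splits the collapsed output axis),
    # while axes after end_dim shift left by the number of collapsed axes.
    end_dim = end_dim + num_axes if end_dim < 0 else end_dim
    diff = end_dim - start_dim
    pairs = [(i, i) for i in range(0, start_dim + 1)]
    pairs += [(i, i - diff) for i in range(end_dim + 1, num_axes)]
    return pairs

# in shape (2, 3, 4, 5) with start_dim=1, end_dim=2 -> out shape (2, 12, 5)
assert flatten_split_pairs(4, 1, 2) == [(0, 0), (1, 1), (3, 2)]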