Unverified · Commit b87d8458, authored by Houjiang Chen, committed by GitHub

Dev fix and align interface (#6075)

* Fix treat_args_as_intlist logic

* Align pytorch interface

* auto format by CI

* Fix empty and reformat

* fix

* auto format by CI

* fix merge
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent fc309307
@@ -54,7 +54,7 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector<Pytho
   if (max_pos_args == 1) {
     const auto& type = function.argument_def.at(0).type;
     treat_args_as_intlist = (type == kINT32_LIST || type == kUINT32_LIST || type == kINT64_LIST
-                             || type == kUINT64_LIST);
+                             || type == kUINT64_LIST || type == kSHAPE);
   }
   if (nargs > max_pos_args && !treat_args_as_intlist) {
     if (raise_exception) {
@@ -67,10 +67,7 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector<Pytho
   for (int i = 0; i < function.argument_def.size(); ++i) {
     const auto& param = function.argument_def.at(i);
     py::object obj;
-    if (arg_pos == 0 && treat_args_as_intlist && !param.keyword_only) {
-      obj = args;
-      arg_pos = nargs;
-    } else if (arg_pos < nargs) {
+    if (arg_pos < nargs) {
       if (param.keyword_only) {
         if (raise_exception) {
           THROW(TypeError) << function.name << "(): argument '" << param.name
@@ -78,7 +75,7 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector<Pytho
         }
         return false;
       }
-      obj = args[arg_pos++];
+      obj = args[arg_pos];
     } else {
       if (kwargs.contains(param.name.c_str())) {
         obj = kwargs[param.name.c_str()];
@@ -86,12 +83,13 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector<Pytho
       }
     }
-    if (!obj && !param.has_default_value) {
-      if (raise_exception) {
-        THROW(TypeError) << function.name << "(): missing required argument " << param.name;
+    if (obj) {
+      if (arg_pos == 0 && treat_args_as_intlist && !param.keyword_only && PyLong_Check(obj.ptr())) {
+        obj = args;
+        arg_pos = nargs;
+      } else {
+        arg_pos++;
       }
-      return false;
-    } else if (obj) {
       PythonArg arg(obj, param.size);
       if ((obj == Py_None && param.optional) || PythonArgCheck(arg, param.type)) {
         parsed_args->at(i) = std::move(arg);
@@ -104,6 +102,12 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector<Pytho
         return false;
       }
     } else {
+      if (!param.has_default_value) {
+        if (raise_exception) {
+          THROW(TypeError) << function.name << "(): missing required argument " << param.name;
+        }
+        return false;
+      }
       parsed_args->at(i) = PythonArg(param.default_value);
     }
   }
......
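Side note on the ParseArgs change above: positional varargs are now folded into an int list only when the first positional argument is actually a Python int (the PyLong_Check), instead of unconditionally whenever the sole parameter is list- or Shape-typed. A minimal Python sketch of the intended effect, assuming an oneflow build that contains this commit (illustrative only, not a test from this PR):

import oneflow as flow

x = flow.randn(2, 3)    # ints in *args are packed into the shape list
y = flow.randn((2, 3))  # a single sequence argument is passed through as-is
assert x.shape == y.shape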
@@ -295,15 +295,19 @@
   bind_python: False

 - name: "arange"
-  signature:
-    "Tensor (Int64 start, Int64 limit, Int64 delta, DataType dtype=kInt64,
-    Device device=None) => Arange"
+  signature: [
+    "Tensor (Int64 start, Int64 end, Int64 step=1, *, DataType dtype=kInt64,
+    Device device=None) => Arange",
+    "Tensor (Int64 end, *, DataType dtype=kInt64, Device device=None) => Arange2",
+  ]
   bind_python: True

 - name: "consistent_arange"
-  signature:
-    "Tensor (Int64 start, Int64 limit, Int64 delta, DataType dtype=kInt64,
-    Placement placement, SbpList sbp_tuple) => ConsistentArange"
+  signature: [
+    "Tensor (Int64 start, Int64 end, Int64 step=1, *, DataType dtype=kInt64,
+    Placement placement, SbpList sbp) => ConsistentArange",
+    "Tensor (Int64 end, *, DataType dtype=kInt64, Placement placement, SbpList sbp) => ConsistentArange2",
+  ]
   bind_python: True

 - name: "flatten"
@@ -327,19 +331,27 @@
   bind_python: True

 - name: "constant"
-  signature: "Tensor (Shape shape, Scalar value, DataType dtype, Device device=None) => Constant"
+  signature: [
+    "Tensor (Shape shape, Scalar value, *, DataType dtype, Device device=None) => Constant",
+  ]
   bind_python: True

 - name: "consistent_constant"
-  signature: "Tensor (Shape shape, Scalar value, DataType dtype, Placement placement, SbpList sbp_tuple) => ConsistentConstant"
+  signature: [
+    "Tensor (Shape shape, Scalar value, *, DataType dtype, Placement placement, SbpList sbp) => ConsistentConstant",
+  ]
   bind_python: True

 - name: "empty"
-  signature: "Tensor (Shape shape, DataType dtype, Device device=None) => Empty"
+  signature: [
+    "Tensor (Shape shape, *, DataType dtype, Device device=None) => Empty",
+  ]
   bind_python: True

 - name: "consistent_empty"
-  signature: "Tensor (Shape shape, DataType dtype, Placement placement, SbpList sbp_tuple) => ConsistentEmpty"
+  signature: [
+    "Tensor (Shape shape, *, DataType dtype, Placement placement, SbpList sbp) => ConsistentEmpty",
+  ]
   bind_python: True

 - name: "zeros_like"
@@ -351,8 +363,9 @@
   bind_python: True

 - name: "bernoulli"
-  signature:
-    "Tensor (Tensor x, DataType dtype=kFloat, Generator generator=None) => Bernoulli"
+  signature: [
+    "Tensor (Tensor x, *, DataType dtype=kFloat, Generator generator=None) => Bernoulli",
+  ]
   bind_python: True

 - name: "concat"
@@ -1038,22 +1051,31 @@
   bind_python: False

 - name: "rand"
-  signature: "Tensor (Shape shape, DataType dtype=None, Device device=None, Generator generator=None) => Rand"
+  signature: [
+    "Tensor (Shape shape, *, DataType dtype=None, Device device=None,
+    Generator generator=None) => Rand",
+  ]
   bind_python: True

 - name: "consistent_rand"
-  signature: "Tensor (Shape shape, Placement placement, SbpList sbp_tuple, DataType dtype=None,
-    Generator generator=None) => ConsistentRand"
+  signature: [
+    "Tensor (Shape shape, *, Placement placement, SbpList sbp, DataType dtype=None,
+    Generator generator=None) => ConsistentRand",
+  ]
   bind_python: True

 - name: "randn"
-  signature: "Tensor (Shape shape, DataType dtype=None, Device device=None,
-    Generator generator=None) => RandN"
+  signature: [
+    "Tensor (Shape shape, *, DataType dtype=None, Device device=None,
+    Generator generator=None) => RandN",
+  ]
   bind_python: True

 - name: "consistent_randn"
-  signature: "Tensor (Shape shape, Placement placement, SbpList sbp_tuple, DataType dtype=None,
-    Generator generator=None) => ConsistentRandN"
+  signature: [
+    "Tensor (Shape shape, *, Placement placement, SbpList sbp, DataType dtype=None,
+    Generator generator=None) => ConsistentRandN",
+  ]
   bind_python: True

 - name: "randint"
@@ -1094,7 +1116,15 @@
   bind_python: False

 - name: "randperm"
-  signature: "Tensor (Int32 n, Device device=None, Generator generator=None) => RandPerm"
+  signature: [
+    "Tensor (Int32 n, *, Device device=None, Generator generator=None) => RandPerm",
+  ]
+  bind_python: True
+
+- name: "consistent_randperm"
+  signature: [
+    "Tensor (Int32 n, *, Placement placement, SbpList sbp, Generator generator=None) => ConsistentRandPerm",
+  ]
   bind_python: True

 - name: "fused_self_attention"
@@ -1105,10 +1135,6 @@
   signature: "Tensor (Tensor query_mul_key_grad, Tensor value_grad, Tensor hidden_states, Float alpha=1.0) => FusedSelfAttentionGrad"
   bind_python: False

-- name: "consistent_randperm"
-  signature: "Tensor (Int32 n,Placement placement, SbpList sbp_tuple, Generator generator=None) => ConsistentRandperm"
-  bind_python: True
-
 - name: "fused_scale_tril"
   signature: "Tensor (Tensor x, Int64 diagonal=0, Scalar fill_value=0, Scalar scale=1) => FusedScaleTril"
   bind_python: True
......
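The `*` added throughout the signatures above marks everything after it keyword-only, mirroring the PyTorch convention, and the new `Arange2`/`ConsistentRandPerm` entries add end-only and consistent overloads. A hedged sketch of the resulting call sites (assuming a build with this commit; not taken from the PR's tests):

import oneflow as flow

flow.arange(5)                             # end-only overload, dispatched to Arange2
flow.arange(1, 10, 2, dtype=flow.float32)  # dtype/device must now be passed by keyword
flow.randperm(4, generator=flow.Generator())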
@@ -290,6 +290,14 @@ class ArangeFunctor {
   std::shared_ptr<OpExpr> op_;
 };

+class Arange2Functor {
+ public:
+  Maybe<Tensor> operator()(const int64_t& limit, const Symbol<DType>& dtype,
+                           const Optional<Symbol<Device>>& device) const {
+    return Arange(0, limit, 1, dtype, device);
+  }
+};
+
 class ConsistentArangeFunctor {
  public:
   ConsistentArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("range").Output("out").Build()); }
@@ -319,6 +327,15 @@ class ConsistentArangeFunctor {
   std::shared_ptr<OpExpr> op_;
 };

+class ConsistentArange2Functor {
+ public:
+  Maybe<Tensor> operator()(const int64_t& limit, const Symbol<DType>& dtype,
+                           const Symbol<ParallelDesc>& placement,
+                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
+    return ConsistentArange(0, limit, 1, dtype, placement, sbp_tuple);
+  }
+};
+
 class ArgMaxFunctor : public UnaryFunctor {
  public:
   ArgMaxFunctor() { op_ = CHECK_JUST(one::OpBuilder("argmax").Input("in").Output("out").Build()); }
@@ -662,7 +679,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ReduceMeanFunctor>("ReduceMean");
   m.add_functor<impl::TransposeFunctor>("Transpose");
   m.add_functor<impl::ArangeFunctor>("Arange");
+  m.add_functor<impl::Arange2Functor>("Arange2");
   m.add_functor<impl::ConsistentArangeFunctor>("ConsistentArange");
+  m.add_functor<impl::ConsistentArange2Functor>("ConsistentArange2");
   m.add_functor<impl::ArgMaxFunctor>("ArgMax");
   m.add_functor<impl::CastFunctor>("Cast");
   m.add_functor<impl::ClampFunctor>("Clamp");
......
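Both new functors simply delegate to the existing three-argument implementations with start=0 and step=1, so the end-only overload should be exactly equivalent to the explicit form. A sketch of that equivalence (illustrative only, assuming a build with this commit):

import oneflow as flow

# the single-argument overload fills in start=0, step=1
assert flow.arange(5).numpy().tolist() == flow.arange(0, 5, 1).numpy().tolist()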
@@ -80,7 +80,7 @@ class RandFunctor {
     if (dtype.has_value()) {
       dtype_val = JUST(dtype.value())->data_type();
       if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
-        OF_UNIMPLEMENTED() << dtype_val << "not supported in rand";
+        OF_UNIMPLEMENTED() << "Only support float and double in rand().";
       }
     }
@@ -127,7 +127,7 @@ class ConsistentRandFunctor {
     if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
       if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
-        OF_UNIMPLEMENTED() << dtype_val << "not supported in rand";
+        OF_UNIMPLEMENTED() << "Only support float and double in rand().";
       }
     }
@@ -172,7 +172,7 @@ class RandNFunctor {
       dtype_val = JUST(dtype.value())->data_type();
       if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
-        OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
+        OF_UNIMPLEMENTED() << "Only support float and double in randn().";
       }
     }
@@ -220,7 +220,7 @@ class ConsistentRandNFunctor {
       dtype_val = JUST(dtype.value())->data_type();
       if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
-        OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
+        OF_UNIMPLEMENTED() << "Only support float and double in randn().";
       }
     }
......
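The reworded messages fire whenever rand/randn is asked for a dtype other than kFloat or kDouble. Illustrative usage under that rule (hedged, assuming a build with this commit):

import oneflow as flow

ok = flow.rand(2, 3, dtype=flow.float64)  # float and double pass the check
# flow.rand(2, 3, dtype=flow.int32)       # would hit the OF_UNIMPLEMENTED branch above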
@@ -102,7 +102,11 @@ def _setitem(self, key, value):
     if self.is_consistent:
         if isinstance(value, (int, float)):
             value = flow.F.consistent_constant(
-                [1], value, self.dtype, placement=self.placement, sbp=flow.sbp.broadcast
+                [1],
+                value,
+                dtype=self.dtype,
+                placement=self.placement,
+                sbp=flow.sbp.broadcast,
             )
         else:
             if value.is_consistent:
@@ -118,7 +122,7 @@ def _setitem(self, key, value):
                 value = value.to_consistent(self.placement, sbp=flow.sbp.broadcast)
     else:
         if isinstance(value, (int, float)):
-            value = flow.F.constant([1], value, self.dtype, device=self.device)
+            value = flow.F.constant([1], value, dtype=self.dtype, device=self.device)
         else:
             value = value.to(device=self.device)
......
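With dtype now keyword-only, the scalar branch of _setitem spells the constant call out explicitly. A short sketch of the path this serves (illustrative, assuming a build with this commit):

import oneflow as flow

x = flow.zeros(2, 3)
x[0] = 1.5  # the scalar is wrapped via flow.F.constant([1], 1.5, dtype=x.dtype, device=x.device)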
@@ -65,7 +65,7 @@ def arange_op(
     if placement is None:
         if isinstance(device, str):
             device = flow.device(device)
-        res = flow.F.arange(start, end, step, dtype, device)
+        res = flow.F.arange(start, end, step, dtype=dtype, device=device)
     else:
         assert isinstance(
             placement, flow._oneflow_internal.placement
@@ -77,7 +77,9 @@ def arange_op(
         for elem in sbp:
             assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
         assert len(sbp) == len(placement.hierarchy)
-        res = flow.F.consistent_arange(start, end, step, dtype, placement, sbp)
+        res = flow.F.consistent_arange(
+            start, end, step, dtype=dtype, placement=placement, sbp=sbp
+        )
     res.requires_grad = requires_grad
     return res
......
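arange_op picks the local or consistent functor depending on whether a placement is supplied. A hedged sketch of both paths; the consistent call only works under a consistent (multi-process) launch, and the flow.placement constructor shown is an assumption about the API of this era:

import oneflow as flow

a = flow.arange(0, 8, 2)  # local path -> flow.F.arange(..., dtype=..., device=...)

# consistent path (only under a consistent launch; constructor is assumed):
# placement = flow.placement("cpu", {0: [0]})
# b = flow.arange(0, 8, 2, placement=placement, sbp=flow.sbp.broadcast)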
@@ -69,10 +69,16 @@ class _ConstantBase(Module):
     def forward(self):
         if self.placement is not None:
             res = flow.F.consistent_constant(
-                self.shape, self.value, self.dtype, self.placement, self.sbp,
+                self.shape,
+                self.value,
+                dtype=self.dtype,
+                placement=self.placement,
+                sbp=self.sbp,
             )
         else:
-            res = flow.F.constant(self.shape, self.value, self.dtype, self.device,)
+            res = flow.F.constant(
+                self.shape, self.value, dtype=self.dtype, device=self.device
+            )
         res.requires_grad = self.requires_grad
         return res
@@ -361,10 +367,10 @@ class NewOnes(Module):
         ), f"requires_grad parameter not correct, please check!"
         if self.placement is not None:
             res = flow.F.consistent_constant(
-                new_size, 1.0, new_dtype, self.placement, self.sbp
+                new_size, 1.0, dtype=new_dtype, placement=self.placement, sbp=self.sbp
             )
         else:
-            res = flow.F.constant(new_size, 1.0, new_dtype, new_device)
+            res = flow.F.constant(new_size, 1.0, dtype=new_dtype, device=new_device)
         res.requires_grad = new_requires_grad
         return res
......
@@ -86,9 +86,11 @@ def empty_op(
         assert sbp is None, "sbp: %s" % sbp

     if placement is not None:
-        tensor = flow.F.consistent_empty(shape, dtype, placement, sbp)
+        tensor = flow.F.consistent_empty(
+            shape, dtype=dtype, placement=placement, sbp=sbp
+        )
     else:
-        tensor = flow.F.empty(shape, dtype, device,)
+        tensor = flow.F.empty(shape, dtype=dtype, device=device)
     tensor.requires_grad_(requires_grad)
     return tensor
......
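empty_op follows the same local/consistent split. A minimal sketch of the local path after this change (the exact wrapper signature and dtype default are assumptions here):

import oneflow as flow

t = flow.empty(2, 3, dtype=flow.float32)  # resolves to flow.F.empty(shape, dtype=..., device=...)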
@@ -55,7 +55,7 @@ def bernoulli(input, *, generator=None, out=None):
     """
-    return flow.F.bernoulli(input, flow.float32, generator)
+    return flow.F.bernoulli(input, dtype=flow.float32, generator=generator)


 def _rand_op_common_process(
@@ -110,10 +110,19 @@ class Rand(Module):
     def forward(self):
         if self.placement is not None:
             res = flow.F.consistent_rand(
-                self.size, self.placement, self.sbp, self.dtype, self.generator
+                self.size,
+                placement=self.placement,
+                sbp=self.sbp,
+                dtype=self.dtype,
+                generator=self.generator,
             )
         else:
-            res = flow.F.rand(self.size, self.dtype, self.device, self.generator)
+            res = flow.F.rand(
+                self.size,
+                dtype=self.dtype,
+                device=self.device,
+                generator=self.generator,
+            )
         res.requires_grad = self.requires_grad
         return res
@@ -199,10 +208,19 @@ class RandN(Module):
     def forward(self):
         if self.placement is not None:
             res = flow.F.consistent_randn(
-                self.size, self.placement, self.sbp, self.dtype, self.generator
+                self.size,
+                placement=self.placement,
+                sbp=self.sbp,
+                dtype=self.dtype,
+                generator=self.generator,
            )
         else:
-            res = flow.F.randn(self.size, self.dtype, self.device, self.generator)
+            res = flow.F.randn(
+                self.size,
+                dtype=self.dtype,
+                device=self.device,
+                generator=self.generator,
+            )
         res.requires_grad = self.requires_grad
         return res
@@ -402,10 +420,10 @@ class RandPerm(Module):
     def forward(self, out=None):
         if self.placement is not None:
             res = flow.F.consistent_randperm(
-                self.n, self.placement, self.sbp, self.generator
+                self.n, placement=self.placement, sbp=self.sbp, generator=self.generator
             )
         else:
-            res = flow.F.randperm(self.n, self.device, self.generator)
+            res = flow.F.randperm(self.n, device=self.device, generator=self.generator)
         res.requires_grad = self.requires_grad
         return res.to(dtype=self.dtype)
......
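After these changes every optional in the Python wrappers is forwarded by keyword, matching the `*` in the YAML signatures. A usage sketch (same assumptions as above):

import oneflow as flow

g = flow.Generator()
flow.bernoulli(flow.ones(2, 2) * 0.5, generator=g)  # dtype is fixed to float32 internally
flow.randperm(4, generator=g)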