提交 75fec82b 编写于 作者: K kingfo

resolve pynative operator issue

上级 5ed799d7
......@@ -125,7 +125,7 @@ def list_len(x):
return len(x)
# only used in PyNative modes
# only used in PyNative mode
def partial(*args):
"""Implement `partial`."""
func = args[0].__call__
......@@ -133,10 +133,14 @@ def partial(*args):
return partial_func
# only used in PyNative modes
# only used in PyNative mode
def depend(value, expr):
return value
# only used in PyNative mode
def make_ref(key, value, ref):
return value
def scalar_cast(x, t):
"""Implement scalar_cast."""
......
......@@ -616,17 +616,19 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
return ExecDFGraph(info_, args, phase_s);
}
#else
if (backend == "ge") {
std::shared_ptr<py::object> ret_val = std::make_shared<py::object>();
if (backend == "ms" || backend == "ge") {
auto ret_val = std::make_shared<py::object>();
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
return *ret_val;
}
}
if (args.size() > 0) {
return args[0];
if (backend == "ge") {
if (args.size() > 0) {
return args[0];
}
return args;
}
return args;
}
#endif
std::size_t full_arg_size = ArgListSize(phase_s);
......
......@@ -20,11 +20,13 @@
#include <map>
#include <set>
#include <unordered_set>
#include <algorithm>
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
......@@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
return converted_ret;
}
py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature& sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return py_args;
}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexs.find(dtypes[i]);
if (it == type_indexs.end()) {
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else {
it->second.push_back(i);
}
}
std::map<SignatureEnumDType, size_t> dst_type;
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
auto type = it->first;
auto indexs = it->second;
if (indexs.size() < 2) {
continue;
}
size_t m_index = indexs[0];
for (size_t i = 1; i < indexs.size(); ++i) {
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
m_index = indexs[i];
}
}
(void)dst_type.insert(std::make_pair(type, m_index));
}
py::tuple py_inputs(py_args.size());
for (size_t i = 0; i < py_args.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it != dst_type.end() && it->second != i &&
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
if (py::isinstance<py::int_>(py_args[i])) {
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
} else {
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
}
continue;
}
py_inputs[i] = py_args[i];
}
return py_inputs;
}
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
size_t size = py_args.size();
AbstractBasePtrList args_spec_list;
......@@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
auto op_exec_info = std::make_shared<OpExecInfo>();
MS_EXCEPTION_IF_NULL(op_exec_info);
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
if (py::isinstance<py::none>(args[PY_PRIM])) {
py::module ops_mod = py::module::import("mindspore.ops.operations");
py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
py::dict none_attrs = py::dict();
op_exec_info->op_attrs = none_attrs;
} else {
PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto pyobj = prim->GetPyObj();
if (pyobj == nullptr) {
MS_LOG(EXCEPTION) << "pyobj is empty";
}
py::tuple py_args = args[PY_INPUTS];
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, py_args, op_exec_info.get());
}
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto pyobj = prim->GetPyObj();
if (pyobj == nullptr) {
MS_LOG(EXCEPTION) << "pyobj is empty";
}
py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, py_args, op_exec_info.get());
}
op_exec_info->op_inputs = args[PY_INPUTS];
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = py_args;
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask";
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr;
}
return op_exec_info;
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""Parameter for cell."""
from copy import copy
from copy import copy, deepcopy
import numpy as np
from .initializer import initializer
from .tensor import Tensor
......@@ -156,16 +156,24 @@ class Parameter:
return self.default_input
def __add__(self, other):
return self.default_input + other
res = deepcopy(self)
res.default_input = res.default_input + other
return res
def __sub__(self, other):
return self.default_input - other
res = deepcopy(self)
res.default_input = res.default_input - other
return res
def __mul__(self, other):
return self.default_input * other
res = deepcopy(self)
res.default_input = res.default_input * other
return res
def __truediv__(self, other):
return self.default_input / other
res = deepcopy(self)
res.default_input = res.default_input / other
return res
def set_parameter_data(self, data):
if isinstance(data, (Tensor, list, int, float,
......
......@@ -70,45 +70,60 @@ class Tensor(Tensor_):
return str(self.__str__())
def __add__(self, other):
if not isinstance(other, Tensor):
raise TypeError("input_data must be a tensor")
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(self, other)
return out
def __mul__(self, other):
if not isinstance(other, Tensor):
raise TypeError("input_data must be a tensor")
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(self, other)
return out
def __neg__(self):
return Tensor(-self.asnumpy())
def __iadd__(self, other):
out = self.__add__(other)
return out
def __radd__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, self)
return out
def __imul__(self, other):
out = self.__mul__(other)
return out
def __rmul__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(other, self)
return out
def __truediv__(self, other):
if isinstance(other, (int, float)):
other_tensor = Tensor(other, self.dtype())
elif isinstance(other, Tensor):
other_tensor = other
else:
raise TypeError("unsupported type for div operation")
out = tensor_operator_registry.get('__div__')(self, other_tensor)
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__div__')(self, other)
return out
def __rtruediv__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__div__')(other, self)
return out
def __sub__(self, other):
if not isinstance(other, Tensor):
raise TypeError("input_data must be a tensor")
out = self.__add__(Tensor(-other.asnumpy()))
check_type('tensor operation input', other, (Tensor, float, int))
out = self.__add__(-other)
return out
def __isub__(self, other):
out = self.__sub__(other)
return out
def __rsub__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
return out
def __str__(self):
if self.dtype() == mstype.type_none:
return "Unknown Tensor type!"
......
......@@ -191,7 +191,7 @@ def get_bprop_concat(self):
def bprop(x, out, dout):
dx = ()
out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x)
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
for i in range(F.tuple_len(x)):
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
dx = dx + (slice_out,)
......
......@@ -14,6 +14,6 @@
# ============================================================================
"""ops utils."""
from .broadcast import _get_broadcast_shape
from .utils import _get_broadcast_shape, _get_concat_offset
__all__ = ['_get_broadcast_shape']
__all__ = ['_get_broadcast_shape', '_get_concat_offset']
......@@ -13,8 +13,11 @@
# limitations under the License.
# ============================================================================
"""broadcast"""
"""utils for operator"""
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
def _get_broadcast_shape(x_shape, y_shape, prim_name):
"""
......@@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
broadcast_shape = broadcast_shape_front + broadcast_shape_back
return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
......@@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build nerual networks or computing functions.
"""
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack,
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
......@@ -200,7 +200,6 @@ __all__ = [
'LogicalOr',
'Size',
'DepthwiseConv2dNative',
'ConcatOffset',
'UnsortedSegmentSum',
"AllGather",
"AllReduce",
......
......@@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_int_positive, check_bool
from .._utils import _get_concat_offset
from ...common import dtype as mstype
......@@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type
class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""
@prim_attr_register
def __init__(self, N=2, axis=0):
"""init ConcatOffset"""
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
values = []
for j in range(len(x_shp[0])):
value = 0
if j == axis:
value = offset[i]
values.append(value)
offset_values.append(tuple(values))
out = {'shape': None,
'dtype': None,
'value': tuple(offset_values)}
return out
class Conv2DBackpropFilter(PrimitiveWithInfer):
"""
......
......@@ -29,6 +29,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce
from .._utils import _get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims):
......@@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return out
def _get_concat_offset(x_shp, x_type, axis):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
class Concat(PrimitiveWithInfer):
r"""
Concat tensor in specified axis.
......@@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer):
'value': None}
class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""
@prim_attr_register
def __init__(self, N=2, axis=0):
"""init ConcatOffset"""
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
values = []
for j in range(len(x_shp[0])):
value = 0
if j == axis:
value = offset[i]
values.append(value)
offset_values.append(tuple(values))
out = {'shape': None,
'dtype': None,
'value': tuple(offset_values)}
return out
class Select(PrimitiveWithInfer):
r"""
......
......@@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
@prim_attr_register
def __init__(self, tag):
validator.check_type('tag', tag, (str,))
def __call__(self):
pass
......@@ -24,6 +24,7 @@ import pytest
import mindspore as ms
import mindspore.common.api as me
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from ..ut_filter import non_graph_engine
......@@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
input = ms.Tensor(input)
input_me = ms.Tensor(input, dtype=ms.bool_)
def test_tensor_operation():
x = Tensor(np.ones((3,3)) * 4)
res = x + 1
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
res = 1 + x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
res = x - 2
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 6 - x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = x * 3
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
res = 3 * x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
res = x / 2
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 8 / x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
with pytest.raises(TypeError):
res = x * (2, 3)
......@@ -190,7 +190,7 @@ def vm_impl_slice(self):
return vm_impl
@vm_impl_getters.register(P.ConcatOffset)
@vm_impl_getters.register(P._grad_ops.ConcatOffset)
def vm_impl_concatOffset(self):
"""Generate vm_impl function for ConcatOffset"""
def vm_impl(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册