提交 75fec82b 编写于 作者: K kingfo

resolve pynative operator issue

上级 5ed799d7
...@@ -125,7 +125,7 @@ def list_len(x): ...@@ -125,7 +125,7 @@ def list_len(x):
return len(x) return len(x)
# only used in PyNative modes # only used in PyNative mode
def partial(*args): def partial(*args):
"""Implement `partial`.""" """Implement `partial`."""
func = args[0].__call__ func = args[0].__call__
...@@ -133,10 +133,14 @@ def partial(*args): ...@@ -133,10 +133,14 @@ def partial(*args):
return partial_func return partial_func
# only used in PyNative modes # only used in PyNative mode
def depend(value, expr): def depend(value, expr):
return value return value
# only used in PyNative mode
def make_ref(key, value, ref):
return value
def scalar_cast(x, t): def scalar_cast(x, t):
"""Implement scalar_cast.""" """Implement scalar_cast."""
......
...@@ -616,18 +616,20 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { ...@@ -616,18 +616,20 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
return ExecDFGraph(info_, args, phase_s); return ExecDFGraph(info_, args, phase_s);
} }
#else #else
if (backend == "ge") { if (backend == "ms" || backend == "ge") {
std::shared_ptr<py::object> ret_val = std::make_shared<py::object>(); auto ret_val = std::make_shared<py::object>();
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
return *ret_val; return *ret_val;
} }
} }
if (backend == "ge") {
if (args.size() > 0) { if (args.size() > 0) {
return args[0]; return args[0];
} }
return args; return args;
} }
}
#endif #endif
std::size_t full_arg_size = ArgListSize(phase_s); std::size_t full_arg_size = ArgListSize(phase_s);
if (size > full_arg_size) { if (size > full_arg_size) {
......
...@@ -20,11 +20,13 @@ ...@@ -20,11 +20,13 @@
#include <map> #include <map>
#include <set> #include <set>
#include <unordered_set> #include <unordered_set>
#include <algorithm>
#include "utils/any.h" #include "utils/any.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
#include "pipeline/static_analysis/prim.h" #include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h" #include "session/session_factory.h"
...@@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) { ...@@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
return converted_ret; return converted_ret;
} }
py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature& sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return py_args;
}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexs.find(dtypes[i]);
if (it == type_indexs.end()) {
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else {
it->second.push_back(i);
}
}
std::map<SignatureEnumDType, size_t> dst_type;
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
auto type = it->first;
auto indexs = it->second;
if (indexs.size() < 2) {
continue;
}
size_t m_index = indexs[0];
for (size_t i = 1; i < indexs.size(); ++i) {
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
m_index = indexs[i];
}
}
(void)dst_type.insert(std::make_pair(type, m_index));
}
py::tuple py_inputs(py_args.size());
for (size_t i = 0; i < py_args.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it != dst_type.end() && it->second != i &&
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
if (py::isinstance<py::int_>(py_args[i])) {
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
} else {
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
}
continue;
}
py_inputs[i] = py_args[i];
}
return py_inputs;
}
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
size_t size = py_args.size(); size_t size = py_args.size();
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
...@@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { ...@@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
auto op_exec_info = std::make_shared<OpExecInfo>(); auto op_exec_info = std::make_shared<OpExecInfo>();
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]); op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
if (py::isinstance<py::none>(args[PY_PRIM])) { auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
py::module ops_mod = py::module::import("mindspore.ops.operations");
py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
py::dict none_attrs = py::dict();
op_exec_info->op_attrs = none_attrs;
} else {
PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto pyobj = prim->GetPyObj(); auto pyobj = prim->GetPyObj();
if (pyobj == nullptr) { if (pyobj == nullptr) {
MS_LOG(EXCEPTION) << "pyobj is empty"; MS_LOG(EXCEPTION) << "pyobj is empty";
} }
py::tuple py_args = args[PY_INPUTS]; py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
// use python infer method // use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, py_args, op_exec_info.get()); PynativeInfer(prim, py_args, op_exec_info.get());
} }
op_exec_info->py_primitive = prim; op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
} op_exec_info->op_inputs = py_args;
op_exec_info->op_inputs = args[PY_INPUTS];
op_exec_info->inputs_mask = args[PY_INPUT_MASK]; op_exec_info->inputs_mask = args[PY_INPUT_MASK];
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask"; MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr; return nullptr;
} }
return op_exec_info; return op_exec_info;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""Parameter for cell.""" """Parameter for cell."""
from copy import copy from copy import copy, deepcopy
import numpy as np import numpy as np
from .initializer import initializer from .initializer import initializer
from .tensor import Tensor from .tensor import Tensor
...@@ -156,16 +156,24 @@ class Parameter: ...@@ -156,16 +156,24 @@ class Parameter:
return self.default_input return self.default_input
def __add__(self, other): def __add__(self, other):
return self.default_input + other res = deepcopy(self)
res.default_input = res.default_input + other
return res
def __sub__(self, other): def __sub__(self, other):
return self.default_input - other res = deepcopy(self)
res.default_input = res.default_input - other
return res
def __mul__(self, other): def __mul__(self, other):
return self.default_input * other res = deepcopy(self)
res.default_input = res.default_input * other
return res
def __truediv__(self, other): def __truediv__(self, other):
return self.default_input / other res = deepcopy(self)
res.default_input = res.default_input / other
return res
def set_parameter_data(self, data): def set_parameter_data(self, data):
if isinstance(data, (Tensor, list, int, float, if isinstance(data, (Tensor, list, int, float,
......
...@@ -70,45 +70,60 @@ class Tensor(Tensor_): ...@@ -70,45 +70,60 @@ class Tensor(Tensor_):
return str(self.__str__()) return str(self.__str__())
def __add__(self, other): def __add__(self, other):
if not isinstance(other, Tensor): check_type('tensor input_data', other, (Tensor, float, int))
raise TypeError("input_data must be a tensor")
out = tensor_operator_registry.get('__add__')(self, other) out = tensor_operator_registry.get('__add__')(self, other)
return out return out
def __mul__(self, other): def __mul__(self, other):
if not isinstance(other, Tensor): check_type('tensor input_data', other, (Tensor, float, int))
raise TypeError("input_data must be a tensor")
out = tensor_operator_registry.get('__mul__')(self, other) out = tensor_operator_registry.get('__mul__')(self, other)
return out return out
def __neg__(self):
return Tensor(-self.asnumpy())
def __iadd__(self, other): def __iadd__(self, other):
out = self.__add__(other) out = self.__add__(other)
return out return out
def __radd__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, self)
return out
def __imul__(self, other): def __imul__(self, other):
out = self.__mul__(other) out = self.__mul__(other)
return out return out
def __rmul__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__mul__')(other, self)
return out
def __truediv__(self, other): def __truediv__(self, other):
if isinstance(other, (int, float)): check_type('tensor operation input', other, (Tensor, float, int))
other_tensor = Tensor(other, self.dtype()) out = tensor_operator_registry.get('__div__')(self, other)
elif isinstance(other, Tensor): return out
other_tensor = other
else: def __rtruediv__(self, other):
raise TypeError("unsupported type for div operation") check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__div__')(self, other_tensor) out = tensor_operator_registry.get('__div__')(other, self)
return out return out
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, Tensor): check_type('tensor operation input', other, (Tensor, float, int))
raise TypeError("input_data must be a tensor") out = self.__add__(-other)
out = self.__add__(Tensor(-other.asnumpy()))
return out return out
def __isub__(self, other): def __isub__(self, other):
out = self.__sub__(other) out = self.__sub__(other)
return out return out
def __rsub__(self, other):
check_type('tensor operation input', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
return out
def __str__(self): def __str__(self):
if self.dtype() == mstype.type_none: if self.dtype() == mstype.type_none:
return "Unknown Tensor type!" return "Unknown Tensor type!"
......
...@@ -191,7 +191,7 @@ def get_bprop_concat(self): ...@@ -191,7 +191,7 @@ def get_bprop_concat(self):
def bprop(x, out, dout): def bprop(x, out, dout):
dx = () dx = ()
out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x) out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
for i in range(F.tuple_len(x)): for i in range(F.tuple_len(x)):
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
dx = dx + (slice_out,) dx = dx + (slice_out,)
......
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
# ============================================================================ # ============================================================================
"""ops utils.""" """ops utils."""
from .broadcast import _get_broadcast_shape from .utils import _get_broadcast_shape, _get_concat_offset
__all__ = ['_get_broadcast_shape'] __all__ = ['_get_broadcast_shape', '_get_concat_offset']
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""broadcast""" """utils for operator"""
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
def _get_broadcast_shape(x_shape, y_shape, prim_name): def _get_broadcast_shape(x_shape, y_shape, prim_name):
""" """
...@@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): ...@@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
broadcast_shape = broadcast_shape_front + broadcast_shape_back broadcast_shape = broadcast_shape_front + broadcast_shape_back
return broadcast_shape return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
...@@ -19,7 +19,7 @@ Primitive operator classes. ...@@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build nerual networks or computing functions. A collection of operators to build nerual networks or computing functions.
""" """
from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack, from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye, Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation, Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
...@@ -200,7 +200,6 @@ __all__ = [ ...@@ -200,7 +200,6 @@ __all__ = [
'LogicalOr', 'LogicalOr',
'Size', 'Size',
'DepthwiseConv2dNative', 'DepthwiseConv2dNative',
'ConcatOffset',
'UnsortedSegmentSum', 'UnsortedSegmentSum',
"AllGather", "AllGather",
"AllReduce", "AllReduce",
......
...@@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind ...@@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import ParamValidator as validator from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_int_positive, check_bool from ..._checkparam import Rel, check_int_positive, check_bool
from .._utils import _get_concat_offset
from ...common import dtype as mstype from ...common import dtype as mstype
...@@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): ...@@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type return x_type
class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""
@prim_attr_register
def __init__(self, N=2, axis=0):
"""init ConcatOffset"""
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
values = []
for j in range(len(x_shp[0])):
value = 0
if j == axis:
value = offset[i]
values.append(value)
offset_values.append(tuple(values))
out = {'shape': None,
'dtype': None,
'value': tuple(offset_values)}
return out
class Conv2DBackpropFilter(PrimitiveWithInfer): class Conv2DBackpropFilter(PrimitiveWithInfer):
""" """
......
...@@ -29,6 +29,7 @@ from ..._checkparam import Rel ...@@ -29,6 +29,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from .._utils import _get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims): def _check_infer_attr_reduce(axis, keep_dims):
...@@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer): ...@@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return out return out
def _get_concat_offset(x_shp, x_type, axis):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
class Concat(PrimitiveWithInfer): class Concat(PrimitiveWithInfer):
r""" r"""
Concat tensor in specified axis. Concat tensor in specified axis.
...@@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer): ...@@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer):
'value': None} 'value': None}
class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""
@prim_attr_register
def __init__(self, N=2, axis=0):
"""init ConcatOffset"""
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
values = []
for j in range(len(x_shp[0])):
value = 0
if j == axis:
value = offset[i]
values.append(value)
offset_values.append(tuple(values))
out = {'shape': None,
'dtype': None,
'value': tuple(offset_values)}
return out
class Select(PrimitiveWithInfer): class Select(PrimitiveWithInfer):
r""" r"""
......
...@@ -271,3 +271,6 @@ class MakeRefKey(Primitive): ...@@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
@prim_attr_register @prim_attr_register
def __init__(self, tag): def __init__(self, tag):
validator.check_type('tag', tag, (str,)) validator.check_type('tag', tag, (str,))
def __call__(self):
pass
...@@ -24,6 +24,7 @@ import pytest ...@@ -24,6 +24,7 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.common.api as me import mindspore.common.api as me
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
...@@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool(): ...@@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
input = ms.Tensor(input) input = ms.Tensor(input)
input_me = ms.Tensor(input, dtype=ms.bool_) input_me = ms.Tensor(input, dtype=ms.bool_)
def test_tensor_operation():
x = Tensor(np.ones((3,3)) * 4)
res = x + 1
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
res = 1 + x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 5)
res = x - 2
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 6 - x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = x * 3
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
res = 3 * x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 12)
res = x / 2
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 8 / x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
with pytest.raises(TypeError):
res = x * (2, 3)
...@@ -190,7 +190,7 @@ def vm_impl_slice(self): ...@@ -190,7 +190,7 @@ def vm_impl_slice(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P.ConcatOffset) @vm_impl_getters.register(P._grad_ops.ConcatOffset)
def vm_impl_concatOffset(self): def vm_impl_concatOffset(self):
"""Generate vm_impl function for ConcatOffset""" """Generate vm_impl function for ConcatOffset"""
def vm_impl(x): def vm_impl(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册