提交 c90b66a0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!777 fix bugs and dock ops

Merge pull request !777 from zhangbuxue/fix_bug_and_dock_ops
......@@ -134,7 +134,7 @@ class DebugInfo : public Base {
explicit DebugInfo(const LocationPtr &loc);
virtual ~DebugInfo() = default;
~DebugInfo() override = default;
MS_DECLARE_PARENT(DebugInfo, Base);
int64_t debug_id();
int64_t unique_id() const { return unique_id_; }
......
......@@ -231,10 +231,10 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg);
if (abs == nullptr) {
return "Undefined";
}
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
std::ostringstream oss;
......
......@@ -321,7 +321,7 @@ class TraceTransform : public TraceInfo {
std::string full_name() override { return full_name_ + transform_name_; }
MS_DECLARE_PARENT(TraceTransform, TraceInfo);
virtual std::string symbol() {
std::string symbol() override {
if (transform_name_.empty()) {
return "";
}
......
......@@ -87,6 +87,12 @@ const char *MetaIdLabel(const TypeId &v) {
return "kMetaTypeExternal";
case kMetaTypeNone:
return "kMetaTypeNone";
case kMetaTypeNull:
return "kMetaTypeNull";
case kMetaTypeEllipsis:
return "kMetaTypeEllipsis";
case kMetaTypeEnd:
return "kMetaTypeEnd";
default:
return "[Unknown Type Id]";
}
......
......@@ -1084,6 +1084,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std::vector<unsigned int> shrink;
auto slice_tuple_eles = slice_tuple->elements();
size_t ellipsis_num = 0;
for (size_t index = 0; index < slice_tuple_size; index++) {
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
......@@ -1118,12 +1119,13 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
<< slice_tuple_eles[index]->ToString();
}
for (size_t index = slice_tuple_size; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
if (ellipsis_num == 0) {
for (size_t index = slice_tuple_size; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
}
}
return ConvertBinaryToDecimal(shrink);
}
......@@ -1199,6 +1201,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
return ExpandADim(ret_graph, tensor_node);
}
MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
}
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
......
......@@ -133,7 +133,6 @@ ResolveIRPassLib::ResolveIRPassLib() {
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......@@ -159,7 +159,6 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
}
return false;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -31,7 +31,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
AnfNodePtr func_node, bool is_unpack, bool sens_param) {
MS_EXCEPTION_IF_NULL(func_graph);
......
......@@ -33,7 +33,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
class GradVarPrepare : public AnfVisitor {
......
......@@ -28,13 +28,11 @@
namespace mindspore {
namespace pipeline {
struct ExecutorInfo {
FuncGraphPtr func_graph;
ResourcePtr resource;
std::size_t arg_list_size;
};
using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;
inline std::string GetPhasePrefix(const std::string &phase) {
......
......@@ -101,7 +101,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
MS_LOG(INFO) << "Start new args and compile key:" << key;
g_args_cache[args_spec] = key++;
}
py::tuple argSpec = py::tuple(2);
auto argSpec = py::tuple(2);
argSpec[0] = name;
argSpec[1] = g_args_cache[args_spec];
return argSpec;
......
......@@ -52,11 +52,11 @@ void DoExecNonInputGraph(const std::string &phase) {
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(ERROR) << "Can not found GraphRunner";
return;
}
{
// Release GIL before calling into (potentially long-running) C++ code
py::gil_scoped_release release;
......@@ -181,7 +181,6 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
size_t pos = phase.find('.');
std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
std::string phase_prefix = phase.substr(0, pos);
if (phase_prefix == "export") {
MS_LOG(INFO) << "Set DfGraphConvertor training : false";
convertor.set_training(false);
......@@ -319,19 +318,24 @@ void RunGEInitGraph(const py::dict &init_params, const std::string &phase) {
py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) {
MS_EXCEPTION_IF_NULL(cnode_data);
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (cnode_data->isa<AbstractTensor>()) {
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
BaseShapePtr shape = cnode_data->BuildShape();
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
if (shape_act != tensor_exp.shape()) {
MS_LOG(EXCEPTION) << "The shape of the tensor returned from GE is not the same as "
"the shape of the tensor derived from ME.";
if (!shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString();
}
auto shape_me = shape->cast<abstract::ShapePtr>()->shape();
auto shape_ge = py::cast<Tensor>(data[*count]).shape();
if (shape_ge != shape_me) {
MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge
<< " is not the same as the shape of the tensor derived: " << shape_me;
}
return data[(*count)++];
}
......@@ -343,7 +347,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
auto elements = data_tp->elements();
size_t size = data_tp->size();
py::tuple tp = py::tuple(size);
auto tp = py::tuple(size);
for (size_t i = 0; i < size; i++) {
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
}
......@@ -357,11 +361,11 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
return ValuePtrToPyData(GetValueNode(output_node));
}
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (output_node->isa<Parameter>()) {
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
return data[(*count)++];
}
......@@ -374,7 +378,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
if (output_c->IsApply(prim::kPrimMakeTuple)) {
auto input_list = output_c->inputs();
size_t size = input_list.size();
py::tuple tp = py::tuple(size - 1);
auto tp = py::tuple(size - 1);
for (size_t i = 1; i < size; i++) {
tp[i - 1] = StructureOutput(input_list[i], data, count);
}
......@@ -396,11 +400,8 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr &graph, const std::ve
std::vector<GeTensorPtr> ge_outputs;
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
}
......@@ -473,7 +474,6 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr> &info, const py::
py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::tuple &args,
const std::string &phase) {
std::string phase_prefix = GetPhasePrefix(phase);
if (phase_prefix == "save") {
DoExecNonInputGraph(phase);
ConfigManager::GetInstance().ResetConfig();
......@@ -483,7 +483,6 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
if (info.count(phase) == 0) {
MS_LOG(EXCEPTION) << "There is no phase:" << phase;
}
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
#ifdef ENABLE_INFER
......
......@@ -31,7 +31,6 @@
namespace mindspore {
namespace pipeline {
namespace py = pybind11;
void SetGeOption(const std::map<std::string, std::string> &options);
......@@ -50,7 +49,6 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<int64_t> &input_indexes, const std::string &phase);
void ExportDFGraph(const std::string &file_name, const std::string &phase);
} // namespace pipeline
} // namespace mindspore
......
......@@ -41,7 +41,7 @@ class AbstractFuncAtom : public AbstractFunction {
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
bool operator==(const AbstractFunction &other) const;
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override { return tid(); }
};
......@@ -270,7 +270,7 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
class DummyAbstractClosure : public AbstractFuncAtom {
public:
DummyAbstractClosure() = default;
~DummyAbstractClosure() = default;
~DummyAbstractClosure() override = default;
MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
......
......@@ -295,7 +295,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
......
......@@ -171,20 +171,17 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s
MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
return nullptr;
}
// get tensor buff size
size_t data_buff_size = 0;
size_t elements_num = IntToSize(tensor->ElementsNum());
if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size >= elements_num) {
data_buff_size = elements_num * type_size;
if (UINT_MAX / type_size < elements_num) {
MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
<< " overflowed UINT_MAX: " << UINT_MAX << ".";
return nullptr;
}
// get tensor buff size
size_t data_buff_size = elements_num * type_size;
if (data_buff_size == 0) {
if (elements_num > 0 && type_size > 0 && UINT_MAX / type_size < elements_num) {
MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
<< " overflowed UINT_MAX: " << UINT_MAX << ".";
} else {
MS_LOG(ERROR) << "The Me Tensor data buff size is 0.";
}
return nullptr;
MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
}
// create ge tensor
auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
......
......@@ -56,7 +56,7 @@ class Momentum(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
Tensor[bool], the value is True.
tuple[bool], all elements are True.
Raises:
ValueError: If the momentum is less than 0.0.
......
......@@ -142,8 +142,12 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
from .fused_mul_add import _fused_mul_add_tbe
from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
from .fill_d import _fill_d_op_tbe
from .fill import _fill_op_tbe
from .erf import _erf_op_tbe
from .depthwise_conv2d import _depthwise_conv2d_tbe
from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
from .greater_equal import _greater_equal_tbe
from .not_equal import _not_equal_tbe
from .floor_mod import _floor_mod_tbe
from .scatter_nd_update import _scatter_nd_update_tbe
......@@ -16,7 +16,7 @@
"""FillD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fill_d_op_info = TBERegOp("FillD") \
fill_d_op_info = TBERegOp("Fill") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fill_d.so") \
......@@ -50,6 +50,6 @@ fill_d_op_info = TBERegOp("FillD") \
@op_info_register(fill_d_op_info)
def _fill_d_op_tbe():
def _fill_op_tbe():
"""FillD TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FloorMod op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
floor_mod_op_info = TBERegOp("FloorMod") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("floor_mod.so") \
.compute_cost(10) \
.kernel_name("floor_mod") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(floor_mod_op_info)
def _floor_mod_tbe():
"""FloorMod TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GreaterEqual op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
greater_equal_op_info = TBERegOp("GreaterEqual") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("greater_equal.so") \
.compute_cost(10) \
.kernel_name("greater_equal") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
.get_op_info()
@op_info_register(greater_equal_op_info)
def _greater_equal_tbe():
"""Greater TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NotEqual op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
not_equal_op_info = TBERegOp("NotEqual") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("not_equal.so") \
.compute_cost(10) \
.kernel_name("not_equal") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
.get_op_info()
@op_info_register(not_equal_op_info)
def _not_equal_tbe():
"""Equal TBE register"""
return
......@@ -37,5 +37,5 @@ scatter_nd_op_info = TBERegOp("ScatterNd") \
@op_info_register(scatter_nd_op_info)
def _scatter_nd_tbe():
"""Conv2D TBE register"""
"""ScatterNd TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNdUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_nd_update.so") \
.compute_cost(10) \
.kernel_name("scatter_nd_update") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(scatter_nd_update_op_info)
def _scatter_nd_update_tbe():
"""ScatterNdUpdate TBE register"""
return
......@@ -147,6 +147,21 @@ def _tensor_getitem_by_number(data, number_index):
return _tensor_slice(data, number_index)
@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
"""
Getting item of tensor by None.
Inputs:
data (Tensor): A tensor.
index (None): None.
Outputs:
Tensor, element type is as same as the element type of data.
"""
return _tensor_slice(data, index)
@getitem.register("Tensor", "Slice")
def _tensor_getitem_by_slice(data, slice_index):
"""
......
......@@ -633,15 +633,15 @@ class TruncatedNormal(PrimitiveWithInfer):
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
Inputs:
- **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
- **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.
Outputs:
Tensor, type of output tensor is same as attribute `dtype`.
Examples:
>>> input_shape = Tensor(np.array([1, 2, 3]))
>>> shape = (1, 2, 3)
>>> truncated_normal = P.TruncatedNormal()
>>> output = truncated_normal(input_shape)
>>> output = truncated_normal(shape)
"""
@prim_attr_register
......@@ -651,16 +651,12 @@ class TruncatedNormal(PrimitiveWithInfer):
validator.check_typename('dtype', dtype, mstype.number_type)
def __infer__(self, shape):
shape_t = shape['value']
validator.check_subclass("shape", shape['dtype'], mstype.tensor)
shape_n = shape_t.asnumpy()
if shape_n.ndim != 1:
raise ValueError('The rank of input shape must be 1.')
if shape_n.dtype not in (np.int32, np.int64):
raise TypeError('The type of input shape must be int32 or int64.')
for i, item in enumerate(shape_n):
validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
out = {'shape': tuple(shape_n),
shape_value = shape['value']
validator.check_const_input("shape", shape_value)
validator.check_type("shape", shape_value, [tuple])
for i, value in enumerate(shape_value):
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
out = {'shape': shape_value,
'dtype': mstype.tensor_type(self.dtype),
'value': None}
return out
......@@ -1648,20 +1644,20 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
def __infer__(self, x, begin, end, strides):
begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
if begin_shape != strides_shape or end_shape != strides_shape:
raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
validator.check_const_input("begin", begin['value'])
validator.check_const_input("end", end['value'])
validator.check_const_input("strides", strides['value'])
validator.check_type("begin", begin['value'], [tuple])
validator.check_type("end", end['value'], [tuple])
validator.check_type("strides", strides['value'], [tuple])
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
validator.check_const_input("begin", begin_v)
validator.check_const_input("end", end_v)
validator.check_const_input("strides", strides_v)
validator.check_type("begin", begin_v, [tuple])
validator.check_type("end", end_v, [tuple])
validator.check_type("strides", strides_v, [tuple])
x_shape = x['shape']
x_shp_len = len(x_shape)
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
f"must be equal to the dims({x_shp_len}) of input.")
ret_shape = []
append_dimensions = []
shrink_pos = bin(self.shrink_axis_mask)[::-1]
......@@ -1914,6 +1910,11 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, size, align_corners=False):
"""Init ResizeNearestNeighbor"""
validator.check_type("size", size, [tuple, list])
validator.check_type("align_corners", align_corners, [bool])
validator.check_integer("length of size", len(size), 2, Rel.EQ)
for i, value in enumerate(size):
validator.check_integer(f'{i}th value of size', value, 0, Rel.GE)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x):
......
......@@ -1251,7 +1251,8 @@ class Acosh(PrimitiveWithInfer):
Compute inverse hyperbolic cosine of x element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
and the data type of 'input_x' is number, the element in 'input_x' should be greater than or equal to 1.
Outputs:
Tensor, has the same shape as `input_x`.
......
......@@ -753,8 +753,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
if self.stride[0] != self.stride[1]:
raise ValueError("The height and width of stride should be equal,"
f"but got height:{self.stride[0]}, width:{self.stride[1]}")
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name)
if self.dilation[0] != self.dilation[1]:
raise ValueError("The height and width of dilation should be equal,"
f"but got height:{self.dilation[0]}, width:{self.dilation[1]}")
self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1]))
validator.check_value_type('pad', pad, (int,), self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
......@@ -771,13 +778,11 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
kernel_size_h = w_shape[2]
kernel_size_w = w_shape[3]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape
_, _, stride_h, stride_w = self.stride
_, _, dilation_h, dilation_w = self.dilation
if kernel_size_n != 1:
raise ValueError(f"The batch of input weight should be 1, but got {kernel_size_n}")
if self.pad_mode == "valid":
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
......@@ -1198,8 +1203,8 @@ class TopK(PrimitiveWithInfer):
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = topk(input_x, k)
>>> assert values == Tensor(np.array([5, 4, 3]))
>>> assert indices == Tensor(np.array([4, 3, 2]))
>>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
>>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
"""
@prim_attr_register
......
......@@ -372,7 +372,7 @@ test_case_math_ops = [
'desc_bprop': [[3]]}),
('TruncatedNormal', {
'block': P.TruncatedNormal(),
'desc_const': [Tensor(np.array([1, 2, 3]))],
'desc_const': [(1, 2, 3)],
'desc_inputs': [],
'skip': ['backward'],
'add_fake_input': True}),
......@@ -793,8 +793,8 @@ test_case_nn_ops = [
'desc_bprop': [[5, 5]]}),
('DepthwiseConv2dNative_1', {
'block': P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2),
'desc_inputs': [[10, 32, 32, 32], [3, 32, 3, 3]],
'desc_bprop': [[10, 30, 16, 16]]}),
'desc_inputs': [[10, 32, 32, 32], [1, 32, 3, 3]],
'desc_bprop': [[10, 32, 16, 16]]}),
('DepthwiseConv2dNative_2', {
'block': P.DepthwiseConv2dNative(1, (3, 3), pad_mode="same", pad=0, stride=1),
'desc_inputs': [[2592, 2048, 4, 4], [1, 2048, 3, 3]],
......
......@@ -52,8 +52,9 @@ class NetWorkSliceEllipsis(Cell):
def construct(self, tensor):
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
ret1 = tensor[...] + self.tensor_ret1
ret2 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2
ret2 = tensor[None] + self.tensor_ret2
ret3 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2, ret3
class NetWorkReduceDimension(Cell):
......@@ -305,7 +306,7 @@ test_cases = [
'block': NetWorkReduceToScalar(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
}),
('NetWorkSliceEllipsis', {
('TensorSliceEllipsis', {
'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册