Commit ffcbcc8d authored by changzherui

sync code 505

graphengine @ 63cb7293
-Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6
+Subproject commit 63cb729373ae8b1b14bc14176c14dac6d18d0e4d
@@ -320,6 +320,224 @@ class Validator:
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is a subclass of int, so `check_type('x', True, [int])` will fail the check,
# while `check_type('x', True, [bool, int])` will pass it (see the sketch after this class).
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
......
@@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
AnfAlgo::SetOutputAddr(device_address, 0, kernel.get());
+AnfAlgo::SetOutputAddr(device_address, 1, kernel.get());
return;
}
......
@@ -27,7 +27,6 @@ namespace kernel {
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
-constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kPrint = "Print";
constexpr auto kOutputTypes = "output_types";
......
@@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
-.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BatchNorm,
@@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
-.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
FusedBatchNormGpuKernel, half)
}  // namespace kernel
......
@@ -157,9 +157,6 @@ class FusedBatchNormGpuKernel : public GpuKernel {
output_size_list_.push_back(para_size);  // running variance
output_size_list_.push_back(para_size);  // save mean
output_size_list_.push_back(para_size);  // save variance
-if (!is_train_) {
-  output_size_list_.push_back(para_size);  // reserve
-}
return;
}
......
@@ -30,6 +30,9 @@ namespace mindspore {
namespace kernel {
namespace tbe {
static std::map<string, string> tbe_func_adapter_map = {
+{"softmax", "softmax_v2"},
+{"log_softmax", "log_softmax_v2"},
+{"apply_momentum", "apply_momentum_d"},
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},
......
@@ -344,8 +344,23 @@ bool IsNopNode(const AnfNodePtr &node) {
return true;
}
+bool IsAllNopNode(session::KernelGraph *const graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto execution_order = graph->execution_order();
+  for (auto &cnode : execution_order) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (!IsNopNode(cnode)) {
+      return false;
+    }
+  }
+  return true;
+}
void HideNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
+if (IsAllNopNode(graph) == true) {
+  return;
+}
auto execution_order = graph->execution_order();
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
std::vector<CNodePtr> new_nodes;
@@ -361,6 +376,9 @@ void HideNopNode(session::KernelGraph *const graph) {
void RemoveNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
+if (IsAllNopNode(graph) == true) {
+  return;
+}
bool changed = true;
while (changed) {
changed = false;
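The intent of the new IsAllNopNode guard: if every node in the execution order is a nop node, hiding or removing them all would leave an empty graph, so both passes now bail out early. A language-neutral Python sketch of the guard (the `kind` attribute is a made-up stand-in for the C++ predicate):

def is_nop(node):
    # Stand-in for the C++ IsNopNode predicate above.
    return getattr(node, 'kind', None) == 'nop'

def remove_nop_nodes(execution_order):
    # Guard added by this commit: a graph made solely of nop nodes is kept as-is.
    if all(is_nop(n) for n in execution_order):
        return list(execution_order)
    return [n for n in execution_order if not is_nop(n)]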
......
@@ -177,6 +177,7 @@ const char kNameAbsGrad[] = "AbsGrad";
const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
+const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD";
const char kNameAcosh[] = "Acosh";
const char kNameAcoshGrad[] = "AcoshGrad";
const char kNameFloorMod[] = "FloorMod";
@@ -206,7 +207,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPool), ADPT_DESC(MaxPool)},
{string(kNameAvgPool), ADPT_DESC(AvgPool)},
{string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
-{string(kNameTopK), ADPT_DESC(TopKV2)},
+{string(kNameTopK), ADPT_DESC(TopK)},
{string(kNamePack), ADPT_DESC(Pack)},
{string(kNameUnpack), ADPT_DESC(Unpack)},
{string(kNameSplitD), ADPT_DESC(SplitD)},
@@ -240,15 +241,15 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameSquare), ADPT_DESC(Square)},
{prim::kPrimTanh->name(), ADPT_DESC(Tanh)},
{prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)},
-{string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborD)},
-{string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborGrad)},
+{string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)},
+{string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)},
{string(kNameApplyAdam), ADPT_DESC(ApplyAdam)},
{string(kNameReLU6), ADPT_DESC(Relu6)},
{string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)},
{string(kNameElu), ADPT_DESC(Elu)},
{string(kNameEluGrad), ADPT_DESC(EluGrad)},
-{string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearGrad)},
-{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearD)},
+{string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)},
+{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
@@ -329,7 +330,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimMinimum->name(), ADPT_DESC(Minimum)},
{prim::kPrimSelect->name(), ADPT_DESC(Select)},
{string(kNameLessEqual), ADPT_DESC(LessEqual)},
-{prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmax)},
+{prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)},
{string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)},
{string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)},
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
@@ -363,7 +364,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimMatMul->name(), ADPT_DESC(MatMul)},
{string(kNameConst), ADPT_DESC(Constant, Const)},
-{string(kNameSoftmax), ADPT_DESC(Softmax)},
+{string(kNameSoftmax), ADPT_DESC(SoftmaxV2)},
{string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)},
{string(kNameParam), ADPT_DESC(Data)},
{string(kNameROIAlign), ADPT_DESC(ROIAlign)},
@@ -373,6 +374,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)},
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
+{string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},
@@ -390,6 +392,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
+adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
#endif
return adpt_map;
}
@@ -1127,8 +1130,8 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
if (desc == nullptr) {
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
} else {
-(void)std::static_pointer_cast<Data>(op)->update_input_desc_data(*desc);
-(void)std::static_pointer_cast<Data>(op)->update_output_desc_out(*desc);
+(void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
+(void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
}
}
......
@@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter {
return static_cast<int64_t>(GetValue<int>(value));
}
-// specialization for int to Vector
+// specialization for int or tuple broadcast to Vector
static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name,
                                       const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
return ConvertAnyUtil(value, name, anyTraitsInt);
......
@@ -35,15 +35,21 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor
std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
                                    const AnyTraits<std::vector<int64_t>>) {
-int64_t data = GetValue<int>(value);
+MS_EXCEPTION_IF_NULL(value);
std::vector<int64_t> list;
-int size = 2;  // 2 int in list
if (name == "pad") {
-  size = 4;  // 4 int in list
-  list = TransformUtil::ConvertIntToList(data, size);
+  if (!value->isa<ValueSequeue>()) {
+    MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name();
+  }
+  auto vec = value->cast<ValueSequeuePtr>();
+  list.resize(vec->value().size() + 2);
  list[0] = 1;
  list[1] = 1;
+  (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
+                       [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int>(val)); });
} else {
+  int64_t data = GetValue<int>(value);
+  int size = 2;  // 2 int in list
  list = TransformUtil::ConvertIntToList(data, size);
}
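The shape of this change: previously a scalar pad value was broadcast into a fixed-size list; now a pad tuple is accepted and prefixed with [1, 1] for the N and C dimensions. A small Python sketch of the assumed new semantics (function name is illustrative):

def convert_pad_to_list(pad):
    # New behavior (sketch): accept a tuple of pad values and prepend 1, 1
    # for the batch and channel dimensions, instead of broadcasting one int.
    if not isinstance(pad, (tuple, list)):
        raise TypeError(f'Value should be a tuple, but got {type(pad).__name__}')
    return [1, 1] + [int(v) for v in pad]

print(convert_pad_to_list((2, 2, 3, 3)))  # [1, 1, 2, 2, 3, 3]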
......
@@ -114,20 +114,22 @@ DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborD)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborD)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborGrad)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborGrad)
+DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
+DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D)
+DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
+DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
DECLARE_OP_ADAPTER(ApplyAdam)
DECLARE_OP_USE_OUTPUT(ApplyAdam)
+DECLARE_OP_ADAPTER(ApplyAdamD)
+DECLARE_OP_USE_OUTPUT(ApplyAdamD)
DECLARE_OP_ADAPTER(Relu6)
DECLARE_OP_USE_OUTPUT(Relu6)
DECLARE_OP_ADAPTER(Relu6Grad)
DECLARE_OP_USE_OUTPUT(Relu6Grad)
-DECLARE_OP_ADAPTER(ResizeBilinearD)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearD)
-DECLARE_OP_ADAPTER(ResizeBilinearGrad)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearGrad)
+DECLARE_OP_ADAPTER(ResizeBilinearV2D)
+DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
+DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
+DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
@@ -213,8 +215,8 @@ DECLARE_OP_USE_OUTPUT(Merge)
DECLARE_OP_ADAPTER(Switch)
DECLARE_OP_USE_OUTPUT(Switch)
-DECLARE_OP_ADAPTER(TopKV2)
-DECLARE_OP_USE_OUTPUT(TopKV2)
+DECLARE_OP_ADAPTER(TopK)
+DECLARE_OP_USE_OUTPUT(TopK)
DECLARE_OP_ADAPTER(RealDiv)
DECLARE_OP_USE_OUTPUT(RealDiv)
@@ -264,8 +266,8 @@ DECLARE_OP_ADAPTER(Select)
DECLARE_OP_USE_OUTPUT(Select)
DECLARE_OP_ADAPTER(LessEqual)
DECLARE_OP_USE_OUTPUT(LessEqual)
-DECLARE_OP_ADAPTER(LogSoftmax)
-DECLARE_OP_USE_OUTPUT(LogSoftmax)
+DECLARE_OP_ADAPTER(LogSoftmaxV2)
+DECLARE_OP_USE_OUTPUT(LogSoftmaxV2)
DECLARE_OP_ADAPTER(TruncatedNormal)
DECLARE_OP_USE_OUTPUT(TruncatedNormal)
DECLARE_OP_ADAPTER(StridedSliceGrad)
@@ -400,8 +402,8 @@ DECLARE_OP_ADAPTER(Sigmoid)
DECLARE_OP_USE_OUTPUT(Sigmoid)
DECLARE_OP_ADAPTER(SigmoidGrad)
DECLARE_OP_USE_OUTPUT(SigmoidGrad)
-DECLARE_OP_ADAPTER(Softmax)
-DECLARE_OP_USE_OUTPUT(Softmax)
+DECLARE_OP_ADAPTER(SoftmaxV2)
+DECLARE_OP_USE_OUTPUT(SoftmaxV2)
DECLARE_OP_ADAPTER(SoftmaxGrad)
DECLARE_OP_USE_OUTPUT(SoftmaxGrad)
DECLARE_OP_ADAPTER(Greater)
@@ -444,6 +446,8 @@ DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(ApplyFtrl)
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
+DECLARE_OP_ADAPTER(SparseApplyFtrlD)
+DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(DiagPart)
......
@@ -31,6 +31,7 @@
namespace mindspore {
const char kShapeSeperator[] = ",";
+const char kShapeScalar[] = "[0]";
static std::map<std::string, TypeId> print_type_map = {
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8},
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16},
@@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
return true;
}
template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) {
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
buf_scalar << *data_ptr;
std::cout << buf_scalar.str() << std::endl;
}
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) {
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
if (*data_ptr == true) {
buf_scalar << "True";
} else {
buf_scalar << "False";
}
std::cout << buf_scalar.str() << std::endl;
}
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) {
auto type_iter = print_type_map.find(tensor_type);
auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt8) {
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt16) {
PrintScalarToString<int16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt16) {
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt32) {
PrintScalarToString<int32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt32) {
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt64) {
PrintScalarToString<int64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt64) {
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat16) {
PrintScalarToString<float16>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat32) {
PrintScalarToString<float>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat64) {
PrintScalarToString<double>(str_data_ptr, tensor_type);
} else {
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
}
}
bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
auto type_iter = type_size_map.find(tensor_type);
if (type_iter == type_size_map.end()) {
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
}
if (str_len != type_iter->second) {
return false;
}
return true;
}
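Taken together, the new helpers decode a scalar Print payload: judgeLengthValid confirms the raw byte length matches the declared type, then convertDataItem2Scalar reinterprets the bytes and prints them. A Python sketch of the same idea using struct (the type table here is abbreviated and illustrative):

import struct

# Abbreviated stand-in for the print_type_map / type_size_map tables above.
STRUCT_FMT = {'int32_t': '<i', 'float': '<f', 'bool': '<?'}

def print_scalar(raw: bytes, tensor_type: str):
    fmt = STRUCT_FMT[tensor_type]
    # Mirrors judgeLengthValid: the byte length must match the declared type.
    if len(raw) != struct.calcsize(fmt):
        raise ValueError('The data length received by Print op is invalid.')
    (val,) = struct.unpack(fmt, raw)
    print(f'Tensor shape :1 {tensor_type}\nval:{val}')

print_scalar(struct.pack('<i', 7), 'int32_t')  # prints "Tensor shape :1 int32_t" / "val:7"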
#ifndef NO_DLIB
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
// Acquire Python GIL
@@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
ret_end_sequence = true;
break;
}
+std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
+MS_EXCEPTION_IF_NULL(str_data_ptr);
+if (item.tensorShape_ == kShapeScalar) {
+  if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
+    MS_LOG(EXCEPTION) << "The data length received by Print op is invalid.";
+  }
+  convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_);
+  continue;
+}
std::vector<int> tensor_shape;
size_t totaldims = 1;
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
continue;
}
-std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
-MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorType_ == "string") {
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
......
@@ -377,12 +377,10 @@ def get_bprop_batch_norm(self):
if is_training:
saved_reserve_1 = out[3]
saved_reserve_2 = out[4]
-saved_reserve_3 = out[5]
else:
saved_reserve_1 = mean
saved_reserve_2 = variance
-saved_reserve_3 = variance
-out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3)
+out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
dx = out[0]
dscale = out[1]
dbias = out[2]
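After this change, BatchNorm's forward yields five outputs instead of six, and the bprop reads the two reserve tensors from indices 3 and 4 (reserve_space_3 disappears end to end: TBE registration, infer_shape, and this bprop). A toy illustration of the indexing contract (values fabricated):

# Toy stand-in: BatchNorm forward now yields 5 outputs instead of 6.
out = ('y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2')
saved_reserve_1, saved_reserve_2 = out[3], out[4]  # indices used by the bprop above
assert (saved_reserve_1, saved_reserve_2) == ('reserve_space_1', 'reserve_space_2')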
......
@@ -18,3 +18,8 @@ from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
+from .is_finite import _is_finite_aicpu
+from .reshape import _reshape_aicpu
+from .flatten import _flatten_aicpu
+from .squeeze import _squeeze_aicpu
+from .expand_dims import _expand_dims_aicpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
expand_dims_op_info = AiCPURegOp("ExpandDims") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(expand_dims_op_info)
def _expand_dims_aicpu():
"""ExpandDims AiCPU register"""
return
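The registration pattern used by these new AiCPU op files is purely declarative: the builder assembles the op info, the decorated function body is empty, and the decorator records the op info in a global registry at import time. A minimal sketch of how such a decorator can work (simplified, not the real op_info_register, and the dict format here is hypothetical):

_OP_INFO_REGISTRY = {}

def op_info_register_sketch(op_info):
    # Store the op info keyed by op name; the decorated function stays a no-op.
    def decorator(func):
        _OP_INFO_REGISTRY[op_info["op_name"]] = op_info
        return func
    return decorator

@op_info_register_sketch({"op_name": "ExpandDims", "imply_type": "AiCPU"})
def _expand_dims_aicpu_sketch():
    """ExpandDims AiCPU register (sketch)"""
    return

assert "ExpandDims" in _OP_INFO_REGISTRY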
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
flatten_op_info = AiCPURegOp("Flatten") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()
@op_info_register(flatten_op_info)
def _flatten_aicpu():
"""Flatten AiCPU register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IsFinite op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
is_finite_op_info = AiCPURegOp("IsFinite") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.BOOL_NCHW) \
.get_op_info()
@op_info_register(is_finite_op_info)
def _is_finite_aicpu():
"""IsFinite AiCPU register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
reshape_op_info = AiCPURegOp("Reshape") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(reshape_op_info)
def _reshape_aicpu():
"""Rpeshape AiCPU register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Squeeze op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
squeeze_op_info = AiCPURegOp("Squeeze") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(squeeze_op_info)
def _squeeze_aicpu():
"""Squeeze AiCPU register"""
return
@@ -61,9 +61,6 @@ from .reduce_mean_d import _reduce_mean_d_tbe
from .scatter_nd import _scatter_nd_tbe
from .scatter_nd_d import _scatter_nd_d_tbe
from .reduce_mean import _reduce_mean_tbe
-from .reshape import _reshape_tbe
-from .expand_dims import _expand_dims_tbe
-from .squeeze import _squeeze_tbe
from .tile import _tile_tbe
from .atomic_addr_clean import _atomic_addr_clean_tbe
from .gather_v2 import _gather_v2_tbe
......
@@ -30,22 +30,23 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
.input(3, "grad", False, "required", "all") \
.input(4, "momentum", False, "required", "all") \
.output(0, "var", False, "required", "all") \
+.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-              DataType.F16_Default, DataType.F16_Default) \
+              DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
-              DataType.F16_Default, DataType.F16_5HD) \
+              DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
-              DataType.F16_Default, DataType.F16_C1HWNCoC0) \
+              DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
-              DataType.F16_Default, DataType.F16_FracZ) \
+              DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-              DataType.F32_Default, DataType.F32_Default) \
+              DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
-              DataType.F32_Default, DataType.F32_5HD) \
+              DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
-              DataType.F32_Default, DataType.F32_C1HWNCoC0) \
+              DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
-              DataType.F32_Default, DataType.F32_FracZ) \
+              DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
......
@@ -36,19 +36,18 @@ batch_norm_op_info = TBERegOp("BatchNorm") \
.output(2, "batch_variance", False, "required", "all") \
.output(3, "reserve_space_1", False, "optional", "all") \
.output(4, "reserve_space_2", False, "optional", "all") \
-.output(5, "reserve_space_3", False, "optional", "all") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
              DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
-              DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+              DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
              DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
-              DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+              DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
              DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-              DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+              DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
              DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-              DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+              DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
@@ -599,4 +599,13 @@ class DataType:
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
\ No newline at end of file
+F64_None = ("float64", "")
+F64_Default = ("float64", "DefaultFormat")
+F64_5HD = ("float64", "NC1HWC0")
+F64_FracZ = ("float64", "FracZ")
+F64_FracNZ = ("float64", "FRACTAL_NZ")
+F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
+F64_NCHW = ("float64", "NCHW")
+F64_NHWC = ("float64", "NHWC")
+F64_HWCN = ("float64", "HWCN")
@@ -85,11 +85,11 @@ class BatchNormGrad(PrimitiveWithInfer):
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.add_prim_attr('data_format', "NCHW")
-def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
+def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)
-def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type):
+def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
...@@ -209,7 +209,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): ...@@ -209,7 +209,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
'value': None, 'value': None,
'shape': w_size_v, 'shape': w_size_v,
'dtype': doutput['dtype'], 'dtype': doutput['dtype'],
} }
return out return out
...@@ -349,7 +349,7 @@ class FlattenGrad(PrimitiveWithInfer): ...@@ -349,7 +349,7 @@ class FlattenGrad(PrimitiveWithInfer):
'value': None, 'value': None,
'shape': args[1]['value'], 'shape': args[1]['value'],
'dtype': args[0]['dtype'], 'dtype': args[0]['dtype'],
} }
return out return out
......
@@ -1657,6 +1657,8 @@ class IsFinite(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
+validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
+validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
return mstype.bool_
class FloatStatus(PrimitiveWithInfer):
......
...@@ -580,7 +580,7 @@ class BatchNorm(PrimitiveWithInfer): ...@@ -580,7 +580,7 @@ class BatchNorm(PrimitiveWithInfer):
>>> mean = Tensor(np.ones([64]), mindspore.float32) >>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32) >>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> batch_norm = P.BatchNorm() >>> batch_norm = P.BatchNorm()
>>> output = batch_norm(input_x, scale, bias, mean, variance) >>> output = batch_norm(input_x, scale, bias, mean, variance
""" """
@prim_attr_register @prim_attr_register
...@@ -589,8 +589,7 @@ class BatchNorm(PrimitiveWithInfer): ...@@ -589,8 +589,7 @@ class BatchNorm(PrimitiveWithInfer):
validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
'reserve_space_3'])
def infer_shape(self, input_x, scale, bias, mean, variance): def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
...@@ -600,7 +599,7 @@ class BatchNorm(PrimitiveWithInfer): ...@@ -600,7 +599,7 @@ class BatchNorm(PrimitiveWithInfer):
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance): def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
...@@ -613,7 +612,7 @@ class BatchNorm(PrimitiveWithInfer): ...@@ -613,7 +612,7 @@ class BatchNorm(PrimitiveWithInfer):
else: else:
args_moving = {"mean": mean, "variance": variance} args_moving = {"mean": mean, "variance": variance}
validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name)
return (input_x, scale, bias, input_x, input_x, input_x) return (input_x, scale, bias, input_x, input_x)
class Conv2D(PrimitiveWithInfer): class Conv2D(PrimitiveWithInfer):
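Taken together, the three BatchNorm hunks drop `reserve_space_3`, so the primitive now yields a 5-tuple; any caller unpacking six values must change. A sketch of the new output contract, reusing the [64]-channel shapes from the docstring example above (illustrative only, not part of the commit):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_x = Tensor(np.ones([128, 64, 32, 32]), mindspore.float32)
scale = Tensor(np.ones([64]), mindspore.float32)
bias = Tensor(np.ones([64]), mindspore.float32)
mean = Tensor(np.ones([64]), mindspore.float32)
variance = Tensor(np.ones([64]), mindspore.float32)

batch_norm = P.BatchNorm(is_training=False)
# Five outputs after this change: y, batch_mean, batch_variance,
# reserve_space_1, reserve_space_2 (reserve_space_3 is gone).
y, batch_mean, batch_var, reserve_1, reserve_2 = batch_norm(
    input_x, scale, bias, mean, variance)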
...@@ -1428,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer): ...@@ -1428,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer):
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0): def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
outputs=['output']) outputs=['output'])
self.is_tbe = context.get_context("device_target") == "Ascend"
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
if self.is_tbe:
return v_shape, v_shape
return v_shape return v_shape
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@@ -1440,6 +1442,8 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
+        if self.is_tbe:
+            return g_dtype, g_dtype
         return g_dtype
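The `is_tbe` flag mirrors the TBE kernel on Ascend, which is registered with two outputs; shape and dtype inference now return a pair there, while other backends keep the single output. A sketch of the caller-visible difference (assumes an Ascend target; illustrative, not part of the commit):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

apply_momentum = P.ApplyMomentum()
var = Parameter(Tensor(np.ones([2, 2]), mindspore.float32), name="var")
accum = Parameter(Tensor(np.zeros([2, 2]), mindspore.float32), name="accum")
grad = Tensor(np.full([2, 2], 0.1), mindspore.float32)

out = apply_momentum(var, accum, 0.01, grad, 0.9)
# On the TBE path `out` is a pair; out[0] is the updated variable
# (see the tensor_run_opt fix later in this commit, which indexes [0]).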
@@ -2578,13 +2582,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
         validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape
+        return var_shape, accum_shape

     def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
         args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
         validator.check_tensor_type_same(args, (mstype.float32,), self.name)
         validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
-        return var_type
+        return var_type, accum_type


 class LARSUpdate(PrimitiveWithInfer):
...
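SparseApplyAdagrad's inference now reports the pair `(var, accum)` instead of `var` alone, which is also why `desc_bprop` for this op changes to two `[3, 3]` entries in the test table further down. A sketch with illustrative values (not part of the commit):

import numpy as np
import mindspore
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

sparse_adagrad = P.SparseApplyAdagrad(lr=0.5)
var = Parameter(Tensor(np.ones([3, 3]), mindspore.float32), name="var")
accum = Parameter(Tensor(np.ones([3, 3]), mindspore.float32), name="accum")
grad = Tensor(np.ones([3, 3]), mindspore.float32)
indices = Tensor(np.array([0, 1, 2], np.int32))

# After this change the primitive's inferred output is the pair
# (var, accum) rather than var alone.
var_out, accum_out = sparse_adagrad(var, accum, grad, indices)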
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.expand_dims = P.ExpandDims()
def construct(self, tensor, dim):
return self.expand_dims(tensor, dim)
def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
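The dtype-matrix tests above (and the Flatten, IsFinite, Reshape, and Squeeze files that follow) differ only in the NumPy dtype. If the suite runs under pytest, they could be collapsed into one parametrized case; a sketch reusing the `Net` cell defined above (a refactoring suggestion, not part of the commit):

import numpy as np
import pytest
from mindspore import Tensor

# Hypothetical consolidation of the per-dtype tests above.
DTYPES = [np.bool_, np.int8, np.uint8, np.int16, np.uint16, np.int32,
          np.uint32, np.int64, np.uint64, np.float16, np.float32, np.float64]

@pytest.mark.parametrize("dtype", DTYPES)
def test_expand_dims_all_dtypes(dtype):
    x = np.random.randn(1, 16, 1, 1).astype(dtype)
    output = Net()(Tensor(x), -1)
    assert np.all(output.asnumpy() == np.expand_dims(x, -1))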
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.Flatten()
def construct(self, tensor):
return self.flatten(tensor)
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.isfinite = P.IsFinite()
def construct(self, tensor):
return self.isfinite(tensor)
def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.reshape = P.Reshape()
def construct(self, tensor):
return self.reshape(tensor, (4,4))
def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.squeeze = P.Squeeze()
def construct(self, tensor):
return self.squeeze(tensor)
def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "debug/anf_ir_dump.h"
#include "utils/utils.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
class TestHWBatchNormGradSplit : public BackendCommon {
public:
TestHWBatchNormGradSplit() : get_py_fun_("gtest_input.pre_activate.batch_norm_grad_split", true) {}
public:
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWBatchNormGradSplit, test_split) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_split", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
std::vector<int> shp_b{64};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
auto kernel_graph = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kernel_graph, nullptr);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::BatchNormGradSplit>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kernel_graph);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_split", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
@@ -189,7 +189,8 @@ TEST_F(TestConvert, TestConvertBatchNorm) {

 TEST_F(TestConvert, TestConvertConvBackpropInput) {
   auto prim = prim::kPrimConv2DBackpropInput;
-  prim->AddAttr("stride", MakeValue(1));
+  const std::vector<int> list{1, 1};
+  prim->AddAttr("stride", MakeValue(list));
   prim->AddAttr("pad", MakeValue(0));
   prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
   prim->AddAttr("dilation", MakeValue(1));
@@ -218,7 +219,8 @@ TEST_F(TestConvert, TestConvertConvBackpropInput) {

 TEST_F(TestConvert, TestConvertConvBackpropFilter) {
   auto prim = prim::kPrimConv2DBackpropFilter;
-  prim->AddAttr("stride", MakeValue(1));
+  const std::vector<int> list{1, 1};
+  prim->AddAttr("stride", MakeValue(list));
   prim->AddAttr("pad", MakeValue(0));
   prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
   prim->AddAttr("dilation", MakeValue(1));
...
@@ -38,7 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                    gradient, variable, moment):
     """ tensor_run_opt """
     success = True
-    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
+    new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0]
     success = F.depend(success, F.assign(variable, new_weight))
     return success
...
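The new trailing `[0]` is the caller-side consequence of ApplyMomentum's tuple output on the TBE path: the updated weight is the first element. A backend-agnostic guard could look like this (hypothetical helper, not in the commit):

def run_opt_step(opt, variable, moment, learning_rate, gradient, momentum):
    """Hypothetical wrapper around an ApplyMomentum-style primitive."""
    out = opt(variable, moment, learning_rate, gradient, momentum)
    # On Ascend/TBE the primitive returns a tuple; elsewhere a single tensor.
    return out[0] if isinstance(out, tuple) else out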
@@ -670,7 +670,7 @@ test_case_nn_ops = [
         'skip': []}),
     ('BatchNormGrad', {
         'block': G.BatchNormGrad(),
-        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
+        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
     ('TopK', {
@@ -807,7 +807,7 @@ test_case_nn_ops = [
     ('SparseApplyAdagrad', {
         'block': P.SparseApplyAdagrad(0.5),
         'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
-        'desc_bprop': [3, 3],
+        'desc_bprop': [[3, 3], [3, 3]],
         'skip': ['backward']}),
     ('Flatten_1', {
         'block': NetForFlatten(),
...