提交 80978cf3 编写于 作者: B buxue

support operator ** // % for scalar and tensor, and in not in for dict, ang str concat

上级 8f6b941a
...@@ -83,9 +83,9 @@ convert_object_map = { ...@@ -83,9 +83,9 @@ convert_object_map = {
T.mul: multitype_ops.mul, T.mul: multitype_ops.mul,
T.truediv: multitype_ops.div, T.truediv: multitype_ops.div,
T.getitem: multitype_ops.getitem, T.getitem: multitype_ops.getitem,
T.floordiv: NO_IMPLEMENT, T.floordiv: multitype_ops.floordiv,
T.mod: F.scalar_mod, T.mod: multitype_ops.mod,
T.pow: F.scalar_pow, T.pow: multitype_ops.pow_,
T.matmul: F.dot, T.matmul: F.dot,
T.lshift: NO_IMPLEMENT, T.lshift: NO_IMPLEMENT,
T.rshift: NO_IMPLEMENT, T.rshift: NO_IMPLEMENT,
...@@ -104,8 +104,8 @@ convert_object_map = { ...@@ -104,8 +104,8 @@ convert_object_map = {
T.ge: multitype_ops.greater_equal, T.ge: multitype_ops.greater_equal,
T.is_: F.is_, T.is_: F.is_,
T.is_not: F.is_not, T.is_not: F.is_not,
T.contains: NO_IMPLEMENT, T.contains: F.in_dict,
T.not_contains: NO_IMPLEMENT, T.not_contains: F.not_in_dict,
# system function # system function
T.len: M.ms_len, T.len: M.ms_len,
......
...@@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) { ...@@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) {
} }
template <typename T> template <typename T>
T InnerScalarDiv(T x, T y) { float InnerScalarDiv(T x, T y) {
if (y == 0) { if (y == 0) {
MS_LOG(EXCEPTION) << "Divisor could not be zero"; MS_LOG(EXCEPTION) << "Divisor could not be zero";
} }
...@@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) { ...@@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) {
MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x) MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << "."; << ", y: " << std::to_string(y) << ".";
} }
return x / y; return static_cast<float>(x) / static_cast<float>(y);
} }
int32_t InnerScalarMod(int32_t x, int32_t y) { template <typename T>
T InnerScalarFloordiv(T x, T y) {
auto ret = std::floor(InnerScalarDiv(x, y));
if (std::is_integral<T>::value) {
return static_cast<int>(ret);
}
return ret;
}
template <typename T>
T InnerScalarMod(T x, T y) {
if (y == 0) { if (y == 0) {
MS_LOG(EXCEPTION) << "Could not mod to zero."; MS_LOG(EXCEPTION) << "Could not mod to zero.";
} }
if (IsSignedIntOverflow(x, y, OpType::MOD)) { if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x) MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
<< ", y: " << std::to_string(y) << "."; << ", y: " << std::to_string(y) << ".";
} }
return x % y; if (std::is_integral<T>::value) {
return static_cast<int>(x) % static_cast<int>(y);
}
float x_int = std::floor(x);
float y_int = std::ceil(y);
float max = x_int / y_int;
float ret = x - y * max;
return ret;
} }
float InnerScalarMod(float, float) { MS_LOG(EXCEPTION) << "Float does not support mod operator."; } template <typename T, typename U>
T InnerScalarPow(T x, U y) {
double InnerScalarMod(double, double) { MS_LOG(EXCEPTION) << "Double does not support mod operator."; } return std::pow(x, y);
}
template <typename T, typename U> template <typename T, typename U>
bool InnerScalarEq(T x, U y) { bool InnerScalarEq(T x, U y) {
...@@ -193,6 +211,8 @@ SCALAR_OP(Sub) ...@@ -193,6 +211,8 @@ SCALAR_OP(Sub)
SCALAR_OP(Mul) SCALAR_OP(Mul)
SCALAR_OP(Div) SCALAR_OP(Div)
SCALAR_OP(Mod) SCALAR_OP(Mod)
SCALAR_OP(Pow)
SCALAR_OP(Floordiv)
#define LOGIC_OP(op_t) \ #define LOGIC_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \ ValuePtr Scalar##op_t(const ValuePtrList& list) { \
...@@ -227,6 +247,10 @@ SCALAR_OP(Mod) ...@@ -227,6 +247,10 @@ SCALAR_OP(Mod)
bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \ bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
return MakeValue(sum); \ return MakeValue(sum); \
} \ } \
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \ if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \ bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
return MakeValue(sum); \ return MakeValue(sum); \
......
...@@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list); ...@@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list);
ValuePtr ScalarMul(const ValuePtrList& list); ValuePtr ScalarMul(const ValuePtrList& list);
ValuePtr ScalarDiv(const ValuePtrList& list); ValuePtr ScalarDiv(const ValuePtrList& list);
ValuePtr ScalarMod(const ValuePtrList& list); ValuePtr ScalarMod(const ValuePtrList& list);
ValuePtr ScalarPow(const ValuePtrList& list);
ValuePtr ScalarFloordiv(const ValuePtrList& list);
ValuePtr ScalarUAdd(const ValuePtrList& list); ValuePtr ScalarUAdd(const ValuePtrList& list);
ValuePtr ScalarUSub(const ValuePtrList& list); ValuePtr ScalarUSub(const ValuePtrList& list);
ValuePtr ScalarUSub(const ValuePtrList& list);
ValuePtr ScalarLog(const ValuePtrList& list); ValuePtr ScalarLog(const ValuePtrList& list);
ValuePtr ScalarEq(const ValuePtrList& list); ValuePtr ScalarEq(const ValuePtrList& list);
ValuePtr ScalarLt(const ValuePtrList& list); ValuePtr ScalarLt(const ValuePtrList& list);
......
...@@ -88,14 +88,17 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur ...@@ -88,14 +88,17 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
if (indexs.size() < 2) { if (indexs.size() < 2) {
continue; continue;
} }
size_t m_index = indexs[0];
for (size_t i = 1; i < indexs.size(); ++i) { for (const auto& index : indexs) {
if (args_spec_list[indexs[i]]->isa<abstract::AbstractTensor>()) { AbstractBasePtr arg_value = args_spec_list[index];
m_index = indexs[i]; if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
if (arg_value->isa<abstract::AbstractTensor>()) {
(void)dst_type.insert(std::make_pair(type, index));
break;
} }
}
if (args_spec_list[m_index]->isa<abstract::AbstractTensor>()) {
(void)dst_type.insert(std::make_pair(type, m_index));
} }
} }
return dst_type; return dst_type;
...@@ -119,15 +122,19 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac ...@@ -119,15 +122,19 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature& sig) { return sig.dtype; }); [](const Signature& sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) { if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return; return;
} }
// Stat the index of the arguments with the largest type in the same SignatureEnumDType. // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, size_t> dst_type = GetMaxDtypeIndex(dtypes, args_spec_list); std::map<SignatureEnumDType, size_t> dst_type = GetMaxDtypeIndex(dtypes, args_spec_list);
// Identify which arg requires auto cast // Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {
AbstractBasePtr arg_value = args_spec_list[i];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
auto it = dst_type.find(dtypes[i]); auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == i || !args_spec_list[i]->isa<abstract::AbstractScalar>()) { if (it == dst_type.end() || it->second == i || !arg_value->isa<abstract::AbstractScalar>()) {
continue; continue;
} }
// get source node for cast // get source node for cast
......
...@@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add"); ...@@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub"); const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul"); const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div"); const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod"); const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow"); const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc"); const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
...@@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_ins ...@@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_ins
// Structure // Structure
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple"); const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
...@@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("Bro ...@@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("Bro
const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
// Comm ops // Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
......
...@@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd; ...@@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd;
extern const PrimitivePtr kPrimScalarSub; extern const PrimitivePtr kPrimScalarSub;
extern const PrimitivePtr kPrimScalarMul; extern const PrimitivePtr kPrimScalarMul;
extern const PrimitivePtr kPrimScalarDiv; extern const PrimitivePtr kPrimScalarDiv;
extern const PrimitivePtr kPrimScalarFloordiv;
extern const PrimitivePtr kPrimScalarMod; extern const PrimitivePtr kPrimScalarMod;
extern const PrimitivePtr kPrimScalarPow; extern const PrimitivePtr kPrimScalarPow;
extern const PrimitivePtr kPrimScalarTrunc; extern const PrimitivePtr kPrimScalarTrunc;
...@@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance; ...@@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance;
// Structure // Structure
extern const PrimitivePtr kPrimStringEqual; extern const PrimitivePtr kPrimStringEqual;
extern const PrimitivePtr kPrimStringConcat;
extern const PrimitivePtr kPrimMakeTuple; extern const PrimitivePtr kPrimMakeTuple;
extern const PrimitivePtr kPrimMakeList; extern const PrimitivePtr kPrimMakeList;
extern const PrimitivePtr kPrimMakeDict; extern const PrimitivePtr kPrimMakeDict;
...@@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs; ...@@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs;
extern const PrimitivePtr kPrimControlDepend; extern const PrimitivePtr kPrimControlDepend;
extern const PrimitivePtr kPrimIs_; extern const PrimitivePtr kPrimIs_;
extern const PrimitivePtr kPrimIsNot; extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimMinimumGrad; extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimMaximumGrad; extern const PrimitivePtr kPrimNotInDict;
// Comm ops // Comm ops
extern const PrimitivePtr kPrimMirror; extern const PrimitivePtr kPrimMirror;
......
...@@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr ...@@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i); AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack()); ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack());
if (arg_shape == nullptr) { if (arg_shape == nullptr) {
MS_LOG(EXCEPTION) << "" << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
} }
if (i == 0) { if (i == 0) {
if (arg_shape->shape().size() < 2) { if (arg_shape->shape().size() < 2) {
MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
<< "] should be TensorShape with dimension greater than 1, but shape: " << "] should be TensorShape with dimension greater than 1, but shape: "
<< arg_shape->ToString(); << arg_shape->ToString();
} }
...@@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr ...@@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
} }
if (arg_shape->shape().size() != 1) { if (arg_shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << "" << op_name << " shape of args[" << i MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
<< "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString();
} }
} }
...@@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti ...@@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; MS_LOG(EXCEPTION) << "Arg shape size should >= 1.";
} }
if (arg_shape_list[0] != input_shape_list[1]) { if (arg_shape_list[0] != input_shape_list[1]) {
MS_LOG(EXCEPTION) << "" << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
<< ") should match the second dimension of tensor" << ") should match the second dimension of tensor"
" param[0](which is " " param[0](which is "
<< input_shape_list[1] << ")."; << input_shape_list[1] << ").";
...@@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti ...@@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
TypePtr prob_type = keep_prob->element()->BuildType(); TypePtr prob_type = keep_prob->element()->BuildType();
if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) { if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) {
MS_LOG(EXCEPTION) << "" << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString()
<< "."; << ".";
} }
......
...@@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr ...@@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr
return std::make_shared<AbstractScalar>(!(*t == *x)); return std::make_shared<AbstractScalar>(!(*t == *x));
} }
bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
auto key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
return it != dict_elems.end();
}
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// statement: x in t
// Inputs: x, t
return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
}
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// statement: x not in t
// Inputs: x, t
return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
...@@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP ...@@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr value_x = scalar_x->BuildValue(); ValuePtr value_x = scalar_x->BuildValue();
ValuePtr value_y = scalar_y->BuildValue(); ValuePtr value_y = scalar_y->BuildValue();
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) { if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
<< ", param1: " << value_y->ToString(); << ", param1: " << value_y->ToString();
} }
...@@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP ...@@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<AbstractScalar>(ret); return std::make_shared<AbstractScalar>(ret);
} }
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two scalars whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr value_x = scalar_x->BuildValue();
ValuePtr value_y = scalar_y->BuildValue();
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
<< ", param1: " << value_y->ToString();
}
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
return std::make_shared<AbstractScalar>(ret);
}
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractTuple>(args_spec_list); return std::make_shared<AbstractTuple>(args_spec_list);
...@@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr ...@@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
size_t keys_size = keys->size(); size_t keys_size = keys->size();
if (values->size() != keys_size) { if (values->size() != keys_size) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys' size is not equal with values' size"; MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
} }
std::vector<AbstractAttribute> key_value; std::vector<AbstractAttribute> key_value;
...@@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr ...@@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
ValuePtr keyPtr = key->BuildValue(); ValuePtr keyPtr = key->BuildValue();
MS_EXCEPTION_IF_NULL(keyPtr); MS_EXCEPTION_IF_NULL(keyPtr);
if (!keyPtr->isa<StringImm>()) { if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
} }
std::string key_string = GetValue<std::string>(keyPtr); std::string key_string = GetValue<std::string>(keyPtr);
key_value.emplace_back(key_string, value_list[index]); key_value.emplace_back(key_string, value_list[index]);
...@@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr ...@@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
ValuePtr keyPtr = key->BuildValue(); ValuePtr keyPtr = key->BuildValue();
if (!keyPtr->isa<StringImm>()) { if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
} }
std::string key_string = GetValue<std::string>(keyPtr); std::string key_string = GetValue<std::string>(keyPtr);
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]); return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
...@@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive ...@@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) { if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
} }
std::string key_input = GetValue<std::string>(key_value); std::string key_input = GetValue<std::string>(key_value);
std::string key_actual = kwarg->get_key(); std::string key_actual = kwarg->get_key();
if (key_actual != key_input) { if (key_actual != key_input) {
MS_LOG(EXCEPTION) << "" << op_name MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
<< " evaluator input key should be same as AbstractKeywordArg' key, but input is " << key_input << key_input << ", AbstractKeywordArg' key is " << key_actual;
<< ", AbstractKeywordArg' key is " << key_actual;
} }
return kwarg->get_arg(); return kwarg->get_arg();
} }
...@@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra ...@@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
ValuePtr index_value = index->BuildValue(); ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) { if (!index_value->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got " MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
<< index_value->ToString();
} }
int idx_v = GetValue<int>(index_value); int idx_v = GetValue<int>(index_value);
std::size_t nelems = queue->elements().size(); std::size_t nelems = queue->elements().size();
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " MS_LOG(EXCEPTION) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
<< SizeToInt(nelems) << "), but got " << idx_v << "."; << SizeToInt(nelems) << "), but got " << idx_v << ".";
} }
...@@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra ...@@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
ValuePtr index_value = index->BuildValue(); ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) { if (!index_value->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator index should be an int32 number, but got " MS_LOG(EXCEPTION) << op_name << " evaluator index should be an int32 number, but got " << index_value->ToString();
<< index_value->ToString();
} }
int idx_v = GetValue<int>(index_value); int idx_v = GetValue<int>(index_value);
if (idx_v < 0) { if (idx_v < 0) {
...@@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra ...@@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
AbstractBasePtrList elements = queue->elements(); AbstractBasePtrList elements = queue->elements();
std::size_t nelems = elements.size(); std::size_t nelems = elements.size();
if (uidx_v >= nelems) { if (uidx_v >= nelems) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 MS_LOG(EXCEPTION) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 << ".";
<< ".";
} }
elements[uidx_v] = args_spec_list[2]; elements[uidx_v] = args_spec_list[2];
return std::make_shared<T>(elements); return std::make_shared<T>(elements);
...@@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP ...@@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) { if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
} }
std::string key_str = GetValue<std::string>(key_value); auto key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(), auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](AbstractAttribute &item) { return item.first == key_str; }); [key_str](const AbstractAttribute &item) { return item.first == key_str; });
if (it == dict_elems.end()) { if (it == dict_elems.end()) {
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
...@@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP ...@@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) { if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator key should be string, but got " << key_value->ToString(); MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
} }
std::string key_str = GetValue<std::string>(key_value); std::string key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractAttribute> dict_elems = dict->elements();
...@@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP ...@@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
auto x_shp_value = shape_x->BuildValue(); auto x_shp_value = shape_x->BuildValue();
if (x_shp_value->isa<AnyValue>()) { if (x_shp_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "" << op_name MS_LOG(EXCEPTION) << op_name
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
} }
// Axis can be scalar, tuple or None // Axis can be scalar, tuple or None
AbstractTuplePtr axis = nullptr; AbstractTuplePtr axis = nullptr;
if (args_spec_list[1]->isa<AbstractScalar>()) { if (args_spec_list[1]->isa<AbstractScalar>()) {
MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is scalar"; MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])}; AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
axis = std::make_shared<AbstractTuple>(axis_list); axis = std::make_shared<AbstractTuple>(axis_list);
} else if (args_spec_list[1]->isa<AbstractTuple>()) { } else if (args_spec_list[1]->isa<AbstractTuple>()) {
MS_LOG(DEBUG) << "" << op_name << " evaluator second parameter is tuple"; MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
axis = args_spec_list[1]->cast<AbstractTuplePtr>(); axis = args_spec_list[1]->cast<AbstractTuplePtr>();
} else { } else {
MS_LOG(EXCEPTION) << "" << op_name << " evaluator second parameter should be a scalar or tuple, but got " MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
<< args_spec_list[1]->ToString(); << args_spec_list[1]->ToString();
} }
auto axis_value = axis->BuildValue(); auto axis_value = axis->BuildValue();
if (axis_value->isa<AnyValue>()) { if (axis_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "" << op_name MS_LOG(EXCEPTION) << op_name
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
} }
auto axis_value_ptr = axis_value->cast<ValueTuplePtr>(); auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
......
...@@ -24,36 +24,35 @@ namespace mindspore { ...@@ -24,36 +24,35 @@ namespace mindspore {
namespace prim { namespace prim {
PrimToFunction::PrimToFunction() PrimToFunction::PrimToFunction()
: prim_func_type_map_({ : prim_func_type_map_({// ONE_ARG prim
// ONE_ARG prim {"bool_not", kPrimTypeOneArg},
{"bool_not", kPrimTypeOneArg}, {"scalar_cos", kPrimTypeOneArg},
{"scalar_cos", kPrimTypeOneArg}, {"scalar_exp", kPrimTypeOneArg},
{"scalar_exp", kPrimTypeOneArg}, {"scalar_floor", kPrimTypeOneArg},
{"scalar_floor", kPrimTypeOneArg}, {"scalar_log", kPrimTypeOneArg},
{"scalar_log", kPrimTypeOneArg}, {"scalar_sin", kPrimTypeOneArg},
{"scalar_sin", kPrimTypeOneArg}, {"scalar_tan", kPrimTypeOneArg},
{"scalar_tan", kPrimTypeOneArg}, {"scalar_trunc", kPrimTypeOneArg},
{"scalar_trunc", kPrimTypeOneArg}, {"typeof", kPrimTypeOneArg},
{"typeof", kPrimTypeOneArg}, {"scalar_uadd", kPrimTypeOneArg},
{"scalar_uadd", kPrimTypeOneArg}, {"scalar_usub", kPrimTypeOneArg},
{"scalar_usub", kPrimTypeOneArg}, // TWO_ARGS prim
// TWO_ARGS prim {"scalar_add", kPrimTypeTwoArgs},
{"scalar_add", kPrimTypeTwoArgs}, {"bool_and", kPrimTypeTwoArgs},
{"bool_and", kPrimTypeTwoArgs}, {"bool_eq", kPrimTypeTwoArgs},
{"bool_eq", kPrimTypeTwoArgs}, {"bool_or", kPrimTypeTwoArgs},
{"bool_or", kPrimTypeTwoArgs}, {"scalar_div", kPrimTypeTwoArgs},
{"scalar_div", kPrimTypeTwoArgs}, {"scalar_eq", kPrimTypeTwoArgs},
{"scalar_eq", kPrimTypeTwoArgs}, {"scalar_ge", kPrimTypeTwoArgs},
{"scalar_ge", kPrimTypeTwoArgs}, {"scalar_gt", kPrimTypeTwoArgs},
{"scalar_gt", kPrimTypeTwoArgs}, {"scalar_le", kPrimTypeTwoArgs},
{"scalar_le", kPrimTypeTwoArgs}, {"scalar_lt", kPrimTypeTwoArgs},
{"scalar_lt", kPrimTypeTwoArgs}, {"scalar_ne", kPrimTypeTwoArgs},
{"scalar_ne", kPrimTypeTwoArgs}, {"scalar_mod", kPrimTypeTwoArgs},
{"scalar_mod", kPrimTypeTwoArgs}, {"scalar_mul", kPrimTypeTwoArgs},
{"scalar_mul", kPrimTypeTwoArgs}, {"scalar_pow", kPrimTypeTwoArgs},
{"scalar_pow", kPrimTypeTwoArgs}, {"scalar_sub", kPrimTypeTwoArgs},
{"scalar_sub", kPrimTypeTwoArgs}, {"scalar_floordiv", kPrimTypeTwoArgs}}) {}
}) {}
bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const {
bool result = false; bool result = false;
......
...@@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { ...@@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSwitch, {InferImplSwitch, true}}, {prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimIs_, {InferImplIs_, true}}, {prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}}, {prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
// Maths // Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
...@@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { ...@@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMakeRange, {InferImplMakeRange, false}}, {prim::kPrimMakeRange, {InferImplMakeRange, false}},
{prim::kPrimStopGradient, {InferImplStopGradient, false}}, {prim::kPrimStopGradient, {InferImplStopGradient, false}},
{prim::kPrimStringEqual, {InferImplStringEqual, false}}, {prim::kPrimStringEqual, {InferImplStringEqual, false}},
{prim::kPrimStringConcat, {InferImplStringConcat, false}},
{prim::kPrimDictLen, {InferImplDictLen, false}}, {prim::kPrimDictLen, {InferImplDictLen, false}},
// NN // NN
{prim::kPrimPooling, {InferImplPooling, true}}, {prim::kPrimPooling, {InferImplPooling, true}},
...@@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { ...@@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
{prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
{prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
{prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
{prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
{prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
{prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
{prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
......
...@@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, ...@@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
...@@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive ...@@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
......
...@@ -19,6 +19,9 @@ from .add_impl import add ...@@ -19,6 +19,9 @@ from .add_impl import add
from .sub_impl import sub from .sub_impl import sub
from .mul_impl import mul from .mul_impl import mul
from .div_impl import div from .div_impl import div
from .pow_impl import pow_
from .floordiv_impl import floordiv
from .mod_impl import mod
from .getitem_impl import getitem from .getitem_impl import getitem
from .zeros_like_impl import zeros_like from .zeros_like_impl import zeros_like
from .ones_like_impl import ones_like from .ones_like_impl import ones_like
...@@ -38,6 +41,9 @@ __all__ = [ ...@@ -38,6 +41,9 @@ __all__ = [
'sub', 'sub',
'mul', 'mul',
'div', 'div',
'pow_',
'floordiv',
'mod',
'uadd', 'uadd',
'zeros_like', 'zeros_like',
'ones_like', 'ones_like',
......
...@@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y): ...@@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y):
return F.scalar_add(x, y) return F.scalar_add(x, y)
@add.register("String", "String")
def _string_concat_string(x, y):
"""
Concatenate the string y to the string x.
Args:
x (str): The first input string.
y (str): the second input string.
Returns:
str, concatenate the y to the x.
"""
return F.string_concat(x, y)
@add.register("Number", "Tensor") @add.register("Number", "Tensor")
def _scalar_add_tensor(x, y): def _scalar_add_tensor(x, y):
""" """
...@@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y): ...@@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y):
Returns: Returns:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(x, F.dtype(y)) return F.tensor_add(x, y)
return F.tensor_add(z, y)
@add.register("Tensor", "Number") @add.register("Tensor", "Number")
...@@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y): ...@@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y):
Returns: Returns:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(y, F.dtype(x)) return F.tensor_add(x, y)
return F.tensor_add(x, z)
@add.register("Tensor", "Tensor") @add.register("Tensor", "Tensor")
......
...@@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y): ...@@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y):
Returns: Returns:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(x, F.dtype(y)) return F.tensor_div(x, y)
return F.tensor_div(z, y)
@div.register("Tensor", "Number") @div.register("Tensor", "Number")
...@@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y): ...@@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y):
Returns: Returns:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(y, F.dtype(x)) return F.tensor_div(x, y)
return F.tensor_div(x, z)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `floordiv` operations."""
from ...composite import base
from ... import functional as F
floordiv = base.MultitypeFuncGraph("floordiv")
"""
`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
using ".register" decorator.
"""
@floordiv.register("Number", "Number")
def _floordiv_scalar(x, y):
"""Returns x // y where x and y are all scalars."""
return F.scalar_floordiv(x, y)
@floordiv.register("Tensor", "Tensor")
def _floordiv_tensor(x, y):
"""Returns x // y where x and y are all tensors and have save dtype."""
return F.tensor_floordiv(x, y)
@floordiv.register("Tensor", "Number")
def _tensor_floordiv_scalar(x, y):
"""Returns x // y where x is a tensor and y is a scalar. x and y should have same dtype."""
return F.tensor_floordiv(x, y)
@floordiv.register("Number", "Tensor")
def _scalar_floordiv_tensor(x, y):
"""Returns x // y where x is a scalar and y is a tensor. x and y should have same dtype."""
return F.tensor_floordiv(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `mod` operations."""
from ...composite import base
from ... import functional as F
mod = base.MultitypeFuncGraph("mod")
"""
`mod` is a metafuncgraph object which will compute the mod of two objects
using ".register" decorator.
"""
@mod.register("Number", "Number")
def _mod_scalar(x, y):
"""Returns x % y where x and y are all scalars."""
return F.scalar_mod(x, y)
@mod.register("Tensor", "Tensor")
def _mod_tensor(x, y):
"""Returns x % y where x and y are all tensors and have save dtype."""
return F.tensor_mod(x, y)
@mod.register("Tensor", "Number")
def _tensor_mod_scalar(x, y):
"""Returns x % y where x is a tensor and y is a scalar. x and y should have same dtype."""
return F.tensor_mod(x, y)
@mod.register("Number", "Tensor")
def _scalar_mod_tensor(x, y):
"""Returns x % y where x is a scalar and y is a tensor. x and y should have same dtype."""
return F.tensor_mod(x, y)
...@@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y): ...@@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y):
Outputs: Outputs:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(x, F.dtype(y)) return F.tensor_mul(x, y)
return F.tensor_mul(z, y)
@mul.register("Tensor", "Number") @mul.register("Tensor", "Number")
...@@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y): ...@@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y):
Outputs: Outputs:
Tensor, has the same dtype as x. Tensor, has the same dtype as x.
""" """
z = F.scalar_to_tensor(y, F.dtype(x)) return F.tensor_mul(x, y)
return F.tensor_mul(x, z)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `pow` operations."""
from ...composite import base
from ... import functional as F
pow_ = base.MultitypeFuncGraph("pow")
"""
`pow` is a metafuncgraph object which will compute the pow of two objects
using ".register" decorator.
"""
@pow_.register("Number", "Number")
def _pow_scalar(x, y):
"""Returns x ** y where x and y are all scalars."""
return F.scalar_pow(x, y)
@pow_.register("Tensor", "Tensor")
def _pow_tensor(x, y):
"""Returns x ** y where x and y are all tensors and have save dtype."""
return F.tensor_pow(x, y)
@pow_.register("Tensor", "Number")
def _tensor_pow_scalar(x, y):
"""Returns x ** y where x is a tensor and y is a scalar. x and y should have same dtype."""
return F.tensor_pow(x, y)
@pow_.register("Number", "Tensor")
def _scalar_pow_tensor(x, y):
"""Returns x ** y where x is a scalar and y is a tensor. x and y should have same dtype."""
return F.tensor_pow(x, y)
...@@ -41,12 +41,10 @@ def _sub_tensor(x, y): ...@@ -41,12 +41,10 @@ def _sub_tensor(x, y):
@sub.register("Number", "Tensor") @sub.register("Number", "Tensor")
def _scalar_sub_tensor(x, y): def _scalar_sub_tensor(x, y):
"""Returns x - y where x is a scalar and y is a tensor. x and y should have same dtype.""" """Returns x - y where x is a scalar and y is a tensor. x and y should have same dtype."""
z = F.scalar_to_tensor(x, F.dtype(y)) return F.tensor_sub(x, y)
return F.tensor_sub(z, y)
@sub.register("Tensor", "Number") @sub.register("Tensor", "Number")
def _tensor_sub_scalar(x, y): def _tensor_sub_scalar(x, y):
"""Returns x - y where x is a tensor and y is a scalar. x and y should have same dtype.""" """Returns x - y where x is a tensor and y is a scalar. x and y should have same dtype."""
z = F.scalar_to_tensor(y, F.dtype(x)) return F.tensor_sub(x, y)
return F.tensor_sub(x, z)
...@@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual() ...@@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual()
tensor_sub = P.Sub() tensor_sub = P.Sub()
tensor_mul = P.Mul() tensor_mul = P.Mul()
tensor_div = P.RealDiv() tensor_div = P.RealDiv()
tensor_floordiv = P.FloorDiv()
tensor_pow = P.Pow()
tensor_mod = P.FloorMod()
strided_slice = P.StridedSlice() strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape() same_type_shape = P.SameTypeShape()
equal = P.Equal() equal = P.Equal()
...@@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add') ...@@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add')
scalar_mul = Primitive('scalar_mul') scalar_mul = Primitive('scalar_mul')
scalar_sub = Primitive('scalar_sub') scalar_sub = Primitive('scalar_sub')
scalar_div = Primitive('scalar_div') scalar_div = Primitive('scalar_div')
scalar_floordiv = Primitive('scalar_floordiv')
scalar_log = Primitive('scalar_log') scalar_log = Primitive('scalar_log')
scalar_pow = Primitive('scalar_pow') scalar_pow = Primitive('scalar_pow')
scalar_gt = Primitive('scalar_gt') scalar_gt = Primitive('scalar_gt')
...@@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd') ...@@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd')
scalar_usub = Primitive('scalar_usub') scalar_usub = Primitive('scalar_usub')
scalar_mod = Primitive('scalar_mod') scalar_mod = Primitive('scalar_mod')
string_eq = Primitive('string_equal') string_eq = Primitive('string_equal')
string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not") bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or") bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and") bool_and = Primitive("bool_and")
...@@ -104,7 +109,8 @@ logical_not = P.LogicalNot() ...@@ -104,7 +109,8 @@ logical_not = P.LogicalNot()
array_to_scalar = Primitive('array_to_scalar') array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_") is_ = Primitive("is_")
is_not = Primitive("is_not") is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
broadcast_gradient_args = Primitive('BroadcastGradientArgs') broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot') dot = Primitive('dot')
array_reduce = Primitive('array_reduce') array_reduce = Primitive('array_reduce')
......
...@@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer): ...@@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer):
>>> return self.addN(z) >>> return self.addN(z)
>>> >>>
>>> net = NetAddN() >>> net = NetAddN()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32) >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.float32)
>>> net(input_x, input_y, input_x, input_y) >>> net(input_x, input_y, input_x, input_y)
Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32) Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32)
""" """
......
...@@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070(): ...@@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070():
def test_ME_logical_operator_0020(): def test_ME_logical_operator_0020():
""" test_ME_logical_operator_0020 """ """ test_ME_logical_operator_0020 """
logical_operator_base('or') logical_operator_base('or')
def test_ops():
class OpsNet(Cell):
""" OpsNet definition """
def __init__(self, x, y):
super(OpsNet, self).__init__()
self.x = x
self.y = y
self.int = 4
self.float = 3.2
self.str_a = "hello"
self.str_b = "world"
def construct(self, x, y):
h = x // y
m = x ** y
n = x % y
r = self.x // self.y
s = self.x ** self.y
t = self.x % self.y
p = h + m + n
q = r + s + t
ret_pow = p ** q + q ** p
ret_mod = p % q + q % p
ret_floor = p // q + q // p
ret = ret_pow + ret_mod + ret_floor
if self.int > self.float:
if self.str_a + self.str_b == "helloworld":
return ret
return x
net = OpsNet(9, 2)
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net(x, y)
def test_in_dict():
class InDictNet(Cell):
""" InDictNet definition """
def __init__(self, key_in, key_not_in):
super(InDictNet, self).__init__()
self.key_in = key_in
self.key_not_in = key_not_in
def construct(self, x, y, z):
d = {"a": x, "b": y}
ret_in = 1
ret_not_in = 2
if self.key_in in d:
ret_in = d[self.key_in]
if self.key_not_in not in d:
ret_not_in = z
ret = ret_in + ret_not_in
return ret
net = InDictNet("a", "c")
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32))
context.set_context(mode=context.GRAPH_MODE)
net(x, y, z)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册