提交 1f5441d7 编写于 作者: W WilliamLian

remove inferfunction to core

上级 01158763
......@@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
}
std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) {
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
if (dlen < 0) {
for (int i = 0; i < -dlen; ++i) {
(void)shpx.insert(shpx.begin(), 1);
}
} else if (dlen > 0) {
for (int i = 0; i < dlen; i++) {
(void)shpy.insert(shpy.begin(), 1);
}
}
if (shpx.size() != shpy.size()) {
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
}
std::vector<int> shp;
for (size_t i = 0; i < shpx.size(); i++) {
auto a = shpx[i];
auto b = shpy[i];
if (a == 1) {
shp.push_back(b);
} else if (b == 1) {
shp.push_back(a);
} else if (a == -1) {
shp.push_back(b);
} else if (b == -1) {
shp.push_back(a);
} else if (a == b) {
shp.push_back(a);
} else {
return std::vector<int>();
}
}
return shp;
}
} // namespace prim
} // namespace mindspore
......@@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list);
ValuePtr BoolAnd(const ValuePtrList &list);
ValuePtr BoolOr(const ValuePtrList &list);
ValuePtr BoolEq(const ValuePtrList &list);
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
} // namespace prim
} // namespace mindspore
......
......@@ -42,28 +42,13 @@ inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEm
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
// Other miscellaneous
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
#include "abstract/abstract_value.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
class RegisterFrontendPrimitiveEvalHelper {
public:
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, false};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterFrontendPrimitiveEvalHelper() = default;
};
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
......@@ -36,115 +36,12 @@
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "pipeline/jit/parse/data_converter.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace abstract {
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimTypeOf, {InferImplTypeof, false}},
{prim::kPrimHasType, {InferImplHasType, false}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimPack, {InferImplPack, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
{prim::kPrimListMap, {InferImplListMap, false}},
{prim::kPrimListReduce, {InferImplListReduce, false}},
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
{prim::kPrimReducedShape, {InferImplReduceShape, false}},
{prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
{prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
{prim::kPrimShapeMul, {InferImplShapeMul, false}},
{prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
{prim::kPrimListEqual, {InferImplListEqual, false}},
{prim::kPrimMakeRange, {InferImplMakeRange, false}},
{prim::kPrimStopGradient, {InferImplStopGradient, false}},
{prim::kPrimStringEqual, {InferImplStringEqual, false}},
{prim::kPrimStringConcat, {InferImplStringConcat, false}},
{prim::kPrimDictLen, {InferImplDictLen, false}},
// NN
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, true}},
{prim::kPrimJ, {InferImplJ, false}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
// Debug
{prim::kPrimDebug, {InferImplDebug, true}},
// RowTensor
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
};
return prim_eval_implement_map;
}
using mindspore::parse::PyObjectWrapper;
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
......
......@@ -26,19 +26,10 @@
#include <vector>
#include "pipeline/jit/static_analysis/evaluator.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
struct StandartPrimitiveImplReg {
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
bool in_white_list_; // true if this Primitive in white list, else false.
};
using PrimitiveEvalImplMap =
std::unordered_map<PrimitivePtr, StandartPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
class StandardPrimEvaluator : public TrivialPrimEvaluator {
public:
StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
......@@ -179,191 +170,6 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
void ClearPrimEvaluatorMap();
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
} // namespace abstract
} // namespace mindspore
......
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
#define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
#include <string>
#include <memory>
#include "abstract/abstract_value.h"
#include "abstract/param_validator.h"
#include "base/core_ops.h"
namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.
CheckArgsSize(op_name, args_spec_list, 1);
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
}
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
......@@ -14,13 +14,48 @@
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "frontend/operator/cc_implementations.h"
#include "abstract/param_validator.h"
namespace mindspore {
namespace abstract {
namespace {
std::vector<int> BroadcastShape(std::vector<int> shpx, std::vector<int> shpy) {
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
if (dlen < 0) {
for (int i = 0; i < -dlen; ++i) {
(void)shpx.insert(shpx.begin(), 1);
}
} else if (dlen > 0) {
for (int i = 0; i < dlen; i++) {
(void)shpy.insert(shpy.begin(), 1);
}
}
if (shpx.size() != shpy.size()) {
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
}
std::vector<int> shp;
for (size_t i = 0; i < shpx.size(); i++) {
auto a = shpx[i];
auto b = shpy[i];
if (a == 1) {
shp.push_back(b);
} else if (b == 1) {
shp.push_back(a);
} else if (a == -1) {
shp.push_back(b);
} else if (b == -1) {
shp.push_back(a);
} else if (a == b) {
shp.push_back(a);
} else {
return std::vector<int>();
}
}
return shp;
}
} // namespace
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a scalar.
......@@ -65,7 +100,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
(void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
[](const ValuePtr &e) -> int { return GetValue<int>(e); });
std::vector<int> res = prim::BroadcastShape_(shp_x, shp_y);
std::vector<int> res = BroadcastShape(shp_x, shp_y);
if (res.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
......
......@@ -15,8 +15,7 @@
*/
#include "abstract/param_validator.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
......
......@@ -14,8 +14,7 @@
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
......
......@@ -14,10 +14,12 @@
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"
#include "c_ops/conv2d.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace abstract {
......@@ -278,13 +280,6 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
CheckArgsSize(primitive->name(), args_spec_list, 1);
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
......@@ -433,5 +428,91 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
}
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimConv2dPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name());
auto out_channel = conv_prim->GetOutputChannel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
std::vector<int> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
conv_prim->name());
auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3];
auto stride = conv_prim->GetStride();
auto dilation = conv_prim->GetDilation();
auto stride_h = stride[2];
auto stride_w = stride[3];
auto dilation_h = dilation[2];
auto dilation_w = dilation[3];
int h_out = -1;
int w_out = -1;
std::vector<int> pad_list(4, 0);
auto pad_mode = conv_prim->GetPadMode();
if (pad_mode == "valid") {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
} else if (pad_mode == "same") {
h_out = ceil(x_shape[2] / stride_h);
w_out = ceil(x_shape[3] / stride_w);
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
pad_list.emplace_back(floor(pad_needed_h / 2));
pad_list.emplace_back(pad_needed_h / 2);
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == "pad") {
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->GetPad()[0];
auto pad_bottom = conv_prim->GetPad()[1];
auto pad_right = conv_prim->GetPad()[2];
auto pad_left = conv_prim->GetPad()[3];
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
h_out = floor(h_out);
w_out = floor(w_out);
}
conv_prim->SetPadList(pad_list);
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->GetTypeTrack());
types.emplace("w", input_args[1]->GetTypeTrack());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (x_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
}
return std::make_shared<TensorType>(TypeIdToType(x_type));
}
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
Conv2dInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
} // namespace abstract
} // namespace mindspore
......@@ -19,9 +19,9 @@
#include "ir/dtype.h"
#include "utils/ms_utils.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
#include "abstract/param_validator.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
......@@ -35,27 +35,6 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr
return args_spec_list[0];
}
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: An object of AbstractFunction.
CheckArgsSize(primitive->name(), args_spec_list, 1);
MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
if (x == nullptr) {
return std::make_shared<AbstractJTagged>(args_spec_list[0]);
}
AbstractFuncAtomPtrList jv;
auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
jv.push_back(j_closure);
};
x->Visit(build_jv);
return AbstractFunction::MakeAbstractFunction(jv);
}
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
......@@ -196,125 +175,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return depends;
}
bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
if (x_shape.size() != y_shape.size()) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
return false;
}
}
return true;
}
enum State {
SAME,
X_ONE,
Y_ONE,
};
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
const size_t n = reverse_x.size();
for (size_t i = 0; i < n; ++i) {
State curr;
const int32_t x_i = reverse_x[i];
const int32_t y_i = reverse_y[i];
const int reduce_idx = SizeToInt(n - 1 - i);
if (x_i == y_i) {
curr = SAME;
} else if (x_i == 1) {
grad_x_reduce_idx->push_back(reduce_idx);
curr = X_ONE;
} else if (y_i == 1) {
grad_y_reduce_idy->push_back(reduce_idx);
curr = Y_ONE;
} else {
MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs";
}
if (curr == SAME && x_i == 1) {
grad_x_reduce_idx->push_back(reduce_idx);
grad_y_reduce_idy->push_back(reduce_idx);
continue;
}
}
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
}
AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
std::vector<int> reverse_x;
std::vector<int> reverse_y;
(void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
(void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
if (reverse_x.size() > reverse_y.size()) {
reverse_y.resize(reverse_x.size(), 1);
} else {
reverse_x.resize(reverse_y.size(), 1);
}
std::vector<int> grad_x_reduce_idx;
std::vector<int> grad_y_reduce_idy;
ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
AbstractBasePtrList abs_list_x;
AbstractBasePtrList abs_list_y;
(void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
[](int v) { return abstract::FromValue(v); });
(void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
[](int v) { return abstract::FromValue(v); });
auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
AbstractBasePtrList elem_list;
elem_list.push_back(x_reduce_idx);
elem_list.push_back(y_reduce_idx);
return std::make_shared<AbstractTuple>(elem_list);
}
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// this primitive get the index that need to reduce
// input: x's shape and y's shape, inputs should be tuple
// output: tuple of x and y 's reduce index, reduce index should be a tuple
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(arg_x_value);
ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(arg_y_value);
const std::vector<ValuePtr> x_shape = arg_x_value->value();
const std::vector<ValuePtr> y_shape = arg_y_value->value();
bool is_same_shape = CompareShape(x_shape, y_shape);
// if it is the same shape , do not need reduce , return empty tuple
if (is_same_shape) {
AbstractBasePtrList empty_list;
auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
AbstractBasePtrList elem_list;
elem_list.push_back(x_reduce_idx);
elem_list.push_back(y_reduce_idx);
return std::make_shared<AbstractTuple>(elem_list);
}
return BroadcastGradientArgsDiff(x_shape, y_shape);
}
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: Two objects of a subclass of AbstractBase
......
......@@ -15,8 +15,7 @@
*/
#include "abstract/param_validator.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
......@@ -34,38 +33,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
return abs_base;
}
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a pointer to an AbstractBase object
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
<< ".";
}
AbstractBasePtr abs_base = args_spec_list[0];
MS_EXCEPTION_IF_NULL(abs_base);
TypePtr type = abs_base->BuildType();
return std::make_shared<AbstractType>(type);
}
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a pointer to an AbstractBase object and a pointer to a Type
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
auto mode_v = abs_type->GetValueTrack();
MS_EXCEPTION_IF_NULL(mode_v);
if (!mode_v->isa<Type>()) {
MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
}
TypePtr mode_t = mode_v->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
bool v = IsSubtype(args_spec_list[0], mode_t);
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
}
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
......
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractTuple>(args_spec_list);
}
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractList>(args_spec_list);
}
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
size_t keys_size = keys->size();
if (values->size() != keys_size) {
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
}
std::vector<AbstractAttribute> key_value;
AbstractScalarPtr key;
AbstractBasePtrList key_list = keys->elements();
AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) {
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
ValuePtr keyPtr = key->BuildValue();
MS_EXCEPTION_IF_NULL(keyPtr);
if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
}
std::string key_string = GetValue<std::string>(keyPtr);
key_value.emplace_back(key_string, value_list[index]);
}
return std::make_shared<AbstractDictionary>(key_value);
}
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
ValuePtr keyPtr = key->BuildValue();
if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
}
std::string key_string = GetValue<std::string>(keyPtr);
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
}
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and a keyword.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
std::string key_input = GetValue<std::string>(key_value);
std::string key_actual = kwarg->get_key();
if (key_actual != key_input) {
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
}
return kwarg->get_arg();
}
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three scalars whose value is an int32 number.
CheckArgsSize(primitive->name(), args_spec_list, 3);
size_t args_size = args_spec_list.size();
for (size_t index = 0; index < args_size; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
}
if (args_spec_list[index]->isa<AbstractScalar>() &&
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
<< " parameter is an AbstractScalar, but is not an int32 number.";
}
}
// Slice: start, end, step
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
}
template <typename T>
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list and a scalar whose value is an int32 number.
CheckArgsSize(op_name, args_spec_list, 2);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) {
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
// and continue
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
}
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
<< index_value->ToString();
}
int idx_v = GetValue<int>(index_value);
std::size_t nelems = queue->elements().size();
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
}
std::size_t uidx_v = 0;
if (idx_v >= 0) {
uidx_v = IntToSize(idx_v);
} else {
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
}
return queue->elements()[uidx_v];
}
template <typename T>
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
CheckArgsSize(op_name, args_spec_list, 3);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
<< index_value->ToString();
}
int idx_v = GetValue<int>(index_value);
if (idx_v < 0) {
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
<< ".";
}
size_t uidx_v = IntToSize(idx_v);
AbstractBasePtrList elements = queue->elements();
std::size_t nelems = elements.size();
if (uidx_v >= nelems) {
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
<< ".";
}
elements[uidx_v] = args_spec_list[2];
return std::make_shared<T>(elements);
}
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
auto key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
if (it == dict_elems.end()) {
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
}
return it->second;
}
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
std::string key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
if (it != dict_elems.end()) {
int index = it - dict_elems.begin();
dict_elems[IntToSize(index)] = new_ele;
} else {
dict_elems.push_back(new_ele);
}
return std::make_shared<AbstractDictionary>(dict_elems);
}
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
(void)AbstractJoin(list->elements());
return list;
}
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
}
} // namespace abstract
} // namespace mindspore
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "abstract/primitive_infer_map.h"
#include "abstract/abstract_function.h"
#include "abstract/infer_functions.h"
namespace mindspore {
namespace abstract {
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimPack, {InferImplPack, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
// NN
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, true}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
// Debug
{prim::kPrimDebug, {InferImplDebug, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
// RowTensor
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
};
return prim_eval_implement_map;
}
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
auto &prim_eval_map = GetPrimitiveToEvalImplMap();
prim_eval_map[primitive] = impl_reg;
}
} // namespace abstract
} // namespace mindspore
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#include <unordered_map>
#include "ir/primitive.h"
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
struct StandardPrimitiveImplReg {
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
bool in_white_list_; // true if this Primitive in white list, else false.
};
using PrimitiveEvalImplMap =
std::unordered_map<PrimitivePtr, StandardPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, true};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterStandardPrimitiveEvalHelper() = default;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl)
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
......@@ -246,6 +246,25 @@ inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_d
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
// Other miscellaneous
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
// Other primitve not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
......
......@@ -26,87 +26,19 @@
namespace mindspore {
namespace {
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimConv2dPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name());
auto out_channel = conv_prim->GetOutputChannel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
std::vector<int> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
conv_prim->name());
auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3];
auto stride = conv_prim->GetStride();
auto dilation = conv_prim->GetDilation();
auto stride_h = stride[2];
auto stride_w = stride[3];
auto dilation_h = dilation[2];
auto dilation_w = dilation[3];
int h_out = -1;
int w_out = -1;
std::vector<int> pad_list(4, 0);
auto pad_mode = conv_prim->GetPadMode();
if (pad_mode == "valid") {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
} else if (pad_mode == "same") {
h_out = ceil(x_shape[2] / stride_h);
w_out = ceil(x_shape[3] / stride_w);
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
pad_list.emplace_back(floor(pad_needed_h / 2));
pad_list.emplace_back(pad_needed_h / 2);
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == "pad") {
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->GetPad()[0];
auto pad_bottom = conv_prim->GetPad()[1];
auto pad_right = conv_prim->GetPad()[2];
auto pad_left = conv_prim->GetPad()[3];
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
h_out = floor(h_out);
w_out = floor(w_out);
}
conv_prim->SetPadList(pad_list);
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->GetTypeTrack());
types.emplace("w", input_args[1]->GetTypeTrack());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (x_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
}
return std::make_shared<TensorType>(TypeIdToType(x_type));
}
constexpr auto kKernelSize = "kernel_size";
constexpr auto kStride = "stride";
constexpr auto kDilation = "dilation";
constexpr auto kPadMode = "pad_mode";
constexpr auto kPad = "pad";
constexpr auto kMode = "mode";
constexpr auto kGroup = "group";
constexpr auto kOutputChannel = "output channel";
constexpr auto kPadList = "pad_list";
constexpr auto kConv2DName = "Conv2D";
} // namespace
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation,
int group) {
......@@ -130,10 +62,47 @@ void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode
this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
}
std::vector<int> Conv2d::GetKernelSize() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> Conv2d::GetStride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> Conv2d::GetDilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int>>(value_ptr);
}
std::string Conv2d::GetPadMode() const {
auto value_ptr = this->GetAttr(kPadMode);
return GetValue<string>(value_ptr);
}
std::vector<int> Conv2d::GetPad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr);
}
int Conv2d::GetMode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr);
}
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
int Conv2d::GetGroup() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<int>(value_ptr);
}
int Conv2d::GetOutputChannel() const {
auto value_ptr = this->GetAttr(kOutputChannel);
return GetValue<int>(value_ptr);
}
void Conv2d::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
void Conv2d::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
void Conv2d::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
void Conv2d::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
void Conv2d::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
} // namespace mindspore
......@@ -16,79 +16,44 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
#define MINDSPORE_CORE_C_OPS_CONV2D_H
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_
#define MINDSPORE_CORE_C_OPS_CONV2D_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
class Conv2d : public PrimitiveC {
public:
Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
Conv2d();
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
std::vector<int> GetKernelSize() const {
auto value_ptr = this->GetAttr(kKernelSize);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> GetStride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> GetDilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int>>(value_ptr);
}
std::string GetPadMode() const {
auto value_ptr = this->GetAttr(kPadMode);
return GetValue<string>(value_ptr);
}
std::vector<int> GetPad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr);
}
int GetMode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr);
}
int GetGroup() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<int>(value_ptr);
}
int GetOutputChannel() const {
auto value_ptr = this->GetAttr(kOutputChannel);
return GetValue<int>(value_ptr);
}
void SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
void SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
void SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
void SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
private:
inline static const string kKernelSize = "kernel_size";
inline static const string kStride = "stride";
inline static const string kDilation = "dilation";
inline static const string kPadMode = "pad_mode";
inline static const string kPad = "pad";
inline static const string kMode = "mode";
inline static const string kGroup = "group";
inline static const string kOutputChannel = "output channel";
inline static const string kPadList = "pad_list";
inline static const string kConv2DName = "Conv2D";
std::vector<int> GetKernelSize() const;
std::vector<int> GetStride() const;
std::vector<int> GetDilation() const;
std::string GetPadMode() const;
std::vector<int> GetPad() const;
int GetMode() const;
int GetGroup() const;
int GetOutputChannel() const;
void SetKernelSize(const std::vector<int> &kernel_size);
void SetStride(const std::vector<int> &stride);
void SetDilation(const std::vector<int> &dilation);
void SetPadMode(const std::string &pad_mode);
void SetPad(const std::vector<int> &pad);
void SetMode(int mode);
void SetGroup(int group);
void SetOutChannel(int output_channel);
void SetPadList(const std::vector<int> &pad_list);
};
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_
......@@ -16,8 +16,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#include <string>
#include <vector>
#include "ir/primitive.h"
......@@ -25,7 +25,7 @@
namespace mindspore {
class PrimitiveC : public Primitive {
public:
explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; }
explicit PrimitiveC(const std::string &name) : Primitive(name) {}
protected:
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) {
......@@ -34,4 +34,4 @@ class PrimitiveC : public Primitive {
}
};
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
......@@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/abstract/*.cc"
"../../../mindspore/core/ir/*.cc"
"../../../mindspore/core/utils/*.cc"
"../../../mindspore/core/c_ops/*.cc"
"../../../mindspore/ccsrc/common/*.cc"
"../../../mindspore/ccsrc/utils/*.cc"
"../../../mindspore/ccsrc/pipeline/jit/parse/*.cc"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册