Commit be3d79cb authored by M mindspore-ci-bot, committed by Gitee

!4204 add dynamic shape support for GatherV2 and others

Merge pull request !4204 from fary86/adapt_primitive_dynamic_shape
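At a glance, the user-visible effect: GatherV2 (and the other ops touched here) can now consume an input whose shape is only known at run time, e.g. the output of Unique. A minimal sketch, mirroring the test_gatherv2 case added at the bottom of this change:

import numpy as np
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class GatherNet(nn.Cell):
    def __init__(self):
        super(GatherNet, self).__init__()
        self.unq = P.Unique()        # produces a tensor with a data-dependent length
        self.gather = P.GatherV2()   # now infers min/max shapes for its output

    def construct(self, x, y):
        u, _ = self.unq(y)
        return self.gather(x, u, 0)

x = Tensor(np.ones([20, 12], dtype=np.float32))
y = Tensor(np.ones([8], dtype=np.int32))
GatherNet()(x, y)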
......@@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper;
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
"env_getitem"};
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
return ret_abstract;
}
}
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return infer_result;
}
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
......@@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
py::dict dic;
if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic["shape"] = arg_tensor->shape()->shape();
dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) {
dic["min_shape"] = min_shape;
dic["max_shape"] = max_shape;
dic[ATTR_MIN_SHAPE] = min_shape;
dic[ATTR_MAX_SHAPE] = max_shape;
}
}
dic["dtype"] = arg_tensor->BuildType();
dic["value"] = BuildValue(arg_tensor->BuildValue());
dic[ATTR_DTYPE] = arg_tensor->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
} else if (abs_base->isa<AbstractRowTensor>()) {
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
dic[ATTR_SHAPE] = arg->shape()->shape();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractSparseTensor>()) {
auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
dic["shape"] = arg->shape()->shape();
dic["dtype"] = arg->BuildType();
dic["value"] = BuildValue(arg->BuildValue());
dic[ATTR_SHAPE] = arg->shape()->shape();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
ShapeVector shape;
dic["shape"] = shape;
dic["dtype"] = abs_base->BuildType();
dic["value"] = BuildValue(abs_base->BuildValue());
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = abs_base->BuildType();
dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
} else if (abs_base->isa<AbstractSlice>()) {
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
ShapeVector shape;
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = arg_slice->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractRef>()) {
auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
dic["shape"] = py::none();
dic["dtype"] = py::ellipsis();
dic["value"] = py::ellipsis();
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = py::ellipsis();
dic[ATTR_VALUE] = py::ellipsis();
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
......@@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
for (size_t i = 0; i < len; i++) {
py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
shape_tuple[i] = out["shape"];
dtype_tuple[i] = out["dtype"];
shape_tuple[i] = out[ATTR_SHAPE];
dtype_tuple[i] = out[ATTR_DTYPE];
}
dic["shape"] = shape_tuple;
dic["dtype"] = dtype_tuple;
dic["value"] = BuildValue(arg_tuple->BuildValue());
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
} else if (abs_base->isa<AbstractList>()) {
auto arg_list = dyn_cast<AbstractList>(abs_base);
size_t len = arg_list->size();
......@@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
for (size_t i = 0; i < len; i++) {
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
shape_list[i] = out["shape"];
dtype_list[i] = out["dtype"];
shape_list[i] = out[ATTR_SHAPE];
dtype_list[i] = out[ATTR_DTYPE];
}
dic["shape"] = shape_list;
dic["dtype"] = dtype_list;
dic["value"] = BuildValue(arg_list->BuildValue());
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
} else if (abs_base->isa<AbstractNone>()) {
dic["shape"] = py::none();
dic["dtype"] = py::none();
dic["value"] = py::none();
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = py::none();
dic[ATTR_VALUE] = py::none();
} else if (abs_base->isa<AbstractFunction>()) {
dic["shape"] = py::none();
dic["dtype"] = abs_base->BuildType();
dic["value"] = py::none();
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = abs_base->BuildType();
dic[ATTR_VALUE] = py::none();
} else if (abs_base->isa<AbstractUndetermined>()) {
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
dic["shape"] = py::none();
dic["dtype"] = arg->BuildType();
dic["value"] = py::none();
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = arg->BuildType();
dic[ATTR_VALUE] = py::none();
} else {
auto value = abs_base->BuildValue();
if ((*value == *kAnyValue)) {
......@@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
// Convert to AbstractValue based on type and shape
auto out_dtype = output["dtype"];
if (output["value"].is_none()) {
auto out_shape = output["shape"];
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
auto out_dtype = output[ATTR_DTYPE];
if (output[ATTR_VALUE].is_none()) {
auto out_shape = output[ATTR_SHAPE];
py::object min_shape =
output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
py::object max_shape =
output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
......@@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
} // end anonymous namespace
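For orientation, ConvertAbstractToPython/PreparePyInputs above hand every Python '__infer__'/'__check__' call one dict per argument, keyed by the ATTR_* constants; a hedged sketch of such a dict for a tensor whose first dimension is dynamic (values are illustrative only):

example_arg = {
    'shape': [-1, 80],        # ATTR_SHAPE; a negative dimension is unknown until run time
    'min_shape': [1, 80],     # ATTR_MIN_SHAPE/ATTR_MAX_SHAPE only appear in graph mode
    'max_shape': [7800, 80],  #   and only when both bounds are known
    'dtype': 'float32',       # ATTR_DTYPE is really a mindspore type object, not a string
    'value': None,            # ATTR_VALUE stays None unless the value is a known constant
}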
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
auto py_args = PreparePyInputs(prim_py, args);
prim_py->RunCheck(py_args);
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
// Call method 'infer_value' for primitive with this method for constant propagation
py::tuple py_vals(py_args.size());
for (size_t i = 0; i < py_args.size(); ++i) {
py_vals[i] = py_args[i][ATTR_VALUE];
}
py::object py_ret = prim_py->RunInferValue(py_vals);
if (py::isinstance<py::none>(py_ret)) {
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
TypePtr dtype = abs_base->BuildType();
bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
auto res_spec = FromValue(converted_ret);
MS_EXCEPTION_IF_NULL(res_spec);
if (res_spec->isa<AbstractTensor>()) {
// Replace to tensor constant node in specialize
auto res_tensor = res_spec->cast<AbstractTensorPtr>();
res_tensor->set_value(converted_ret);
}
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
}
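A dependency-free mirror of the control flow above, with illustrative names rather than MindSpore APIs: run '__check__' on the prepared dict arguments, take shape/dtype from the registered C++ infer implementation, then attempt 'infer_value' for constant folding:

def eval_py_check_prim(prim, cpp_infer_impl, py_args):
    prim.__check__(*py_args)                 # 1. Python-side argument checking
    abs_result = cpp_infer_impl(py_args)     # 2. shape/dtype from the C++ infer impl
    infer_value = getattr(prim, 'infer_value', None)
    if infer_value is None:
        return abs_result                    # no constant folding possible
    value = infer_value(*(arg['value'] for arg in py_args))
    return abs_result if value is None else value  # 3. fold to a constant if computed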
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
return ret_abstract;
}
}
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
return EvalPyCheckPrim(engine, args);
}
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
auto ret_abstract = AbstractEval(args);
if (ret_abstract != nullptr) {
......
......@@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
std::string ToString() const override { return identifier_ + prim_->name(); }
private:
EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
PrimitivePtr prim_;
const StandardPrimitiveEvalImpl eval_impl_;
};
......
......@@ -308,20 +308,18 @@ void AnalysisEngine::Clear() {
namespace {
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
// Custom Primitive with python infer_shape, infer_type
EvaluatorPtr evaluator = nullptr;
MS_EXCEPTION_IF_NULL(prim);
if (prim->isa<prim::DoSignaturePrimitive>()) {
evaluator = std::make_shared<DoSignatureEvaluator>(prim);
return evaluator;
return std::make_shared<DoSignatureEvaluator>(prim);
}
if (prim->isa<prim::UnpackGraphPrimitive>()) {
evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
return evaluator;
return std::make_shared<UnpackGraphEvaluator>(prim);
}
if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
return evaluator;
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
}
EvaluatorPtr evaluator = nullptr;
if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) {
......
......@@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;
}
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
return;
}
if (prim->name() == "fake_bprop") {
MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
}
......
......@@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) {
if (!HasPyObj()) {
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
}
auto infer_fuc = python_obj_.attr("__infer__");
auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
return infer_fuc(*args);
}
void PrimitivePy::RunCheck(const py::tuple &args) {
if (!HasPyObj()) {
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
}
auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
(void)check_func(*args);
}
py::object PrimitivePy::RunInferValue(const py::tuple &args) {
if (!HasPyObj()) {
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
}
auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
return infer_value(*args);
}
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
.value("builtin", PrimType::kPrimTypeBuiltIn)
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
.value("user_custom", PrimType::kPrimTypeUserCustom);
.value("user_custom", PrimType::kPrimTypeUserCustom)
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str &, py::object>())
......
......@@ -62,6 +62,8 @@ class PrimitivePy : public Primitive {
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
py::dict RunInfer(const py::tuple &args);
void RunCheck(const py::tuple &args);
py::object RunInferValue(const py::tuple &args);
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
bool HasPyObj() { return python_obj_.operator bool(); }
PrimitivePtr Clone() override;
......
......@@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
......@@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
......
......@@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <algorithm>
#include <iterator>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
......@@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
// outputs: dx
return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
}
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
AbstractScalarPtr axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
auto params_shp = params->shape()->shape();
auto indices_shp = indices->shape()->shape();
auto axis_val = GetValue<int>(axis->BuildValue());
auto params_rank = static_cast<int>(params_shp.size());
if (axis_val < 0) {
axis_val += params_rank;
}
auto calc_shape = [axis_val, &params_shp](const ShapeVector &inp_vec) -> ShapeVector {
ShapeVector out_vec;
std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
return out_vec;
};
ShapeVector out_shape = calc_shape(indices_shp);
if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
return std::make_shared<AbstractTensor>(params->element(),
std::make_shared<Shape>(out_shape, min_shape, max_shape));
}
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
}
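The calc_shape lambda splices the indices shape into the params shape at the (normalized) axis; the same rule, written out as a small worked example:

def gather_v2_out_shape(params_shape, indices_shape, axis):
    axis = axis % len(params_shape)  # normalize a negative axis
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

# params (6, 4, 5) gathered along axis 1 with indices of shape (2, 3) -> (6, 2, 3, 5);
# applying the same rule to the indices' min/max shapes gives the output's min/max shapes
assert gather_v2_out_shape([6, 4, 5], [2, 3], 1) == [6, 2, 3, 5]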
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = input->shape()->shape();
bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == Shape::SHP_ANY; });
std::vector<int> tensor_shp({static_cast<int>(shape.size())});
if (has_dyn_shape) {
auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(32));
return std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
}
auto shp_buf_size = sizeof(int) * shape.size();
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, tensor_shp, shape.data(), shp_buf_size);
return tensor->ToAbstract();
}
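So DynamicShape folds to a constant 1-D int32 tensor whenever the input shape is fully known, and only remains a run-time op when some dimension is Shape::SHP_ANY (negative); a small illustration of that decision, with numpy standing in for the constant tensor:

import numpy as np

def dynamic_shape_const_fold(input_shape):
    if any(dim < 0 for dim in input_shape):
        return None                               # left to the run-time DynamicShape kernel
    return np.array(input_shape, dtype=np.int32)  # constant result folded at compile time

print(dynamic_shape_const_fold([3, 2, 1]))  # [3 2 1]
print(dynamic_shape_const_fold([-1, 80]))   # None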
} // namespace abstract
} // namespace mindspore
......@@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
}
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: one tensor.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
return inp->Clone()->Broaden();
}
} // namespace abstract
} // namespace mindspore
......@@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
}
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
CheckArgsSize(primitive->name(), args_spec_list, 5);
AbstractBasePtrList elements;
for (size_t i = 0; i < 3; ++i) {
elements.push_back(args_spec_list[i]->Clone()->Broaden());
}
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
CheckArgsSize(primitive->name(), args_spec_list, 7);
AbstractBasePtrList elements;
for (size_t i = 0; i < 2; ++i) {
elements.push_back(args_spec_list[i]->Clone()->Broaden());
}
return std::make_shared<AbstractTuple>(elements);
}
} // namespace abstract
} // namespace mindspore
......@@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimSqrt, {InferImplSqrt, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
......@@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPack, {InferImplPack, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
......@@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator;
......
......@@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
......@@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
......
......@@ -35,7 +35,8 @@ enum PrimType {
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInferShape, // Primitive operator defined by custom
kPrimTypePyInferTensor, // Primitive operator defined by custom
kPrimTypeUserCustom
kPrimTypeUserCustom,
kPrimTypePyInferCheck // Primitive operator with input args checking method
};
class Primitive : public Named {
......
......@@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
// method names of python primitive called from c++ source code
// 1. infer method name of class 'PrimitiveWithInfer'
const char PY_PRIM_METHOD_INFER[] = "__infer__";
// 2. check method name of class 'PrimitiveWithCheck'
const char PY_PRIM_METHOD_CHECK[] = "__check__";
// 3. method name of class 'PrimitivePy' for constant propagation
const char PY_PRIM_METHOD_INFER_VALUE[] = "infer_value";
// type inference related attributes
const char ATTR_VALUE[] = "value";
const char ATTR_DTYPE[] = "dtype";
const char ATTR_SHAPE[] = "shape";
const char ATTR_MIN_SHAPE[] = "min_shape";
const char ATTR_MAX_SHAPE[] = "max_shape";
} // namespace mindspore
......@@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];
extern const char PY_PRIM_METHOD_INFER[];
extern const char PY_PRIM_METHOD_CHECK[];
extern const char PY_PRIM_METHOD_INFER_VALUE[];
extern const char ATTR_VALUE[];
extern const char ATTR_DTYPE[];
extern const char ATTR_SHAPE[];
extern const char ATTR_MIN_SHAPE[];
extern const char ATTR_MAX_SHAPE[];
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_FLAGS_H
......@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
......@@ -206,6 +206,7 @@ __all__ = [
'HookBackward',
'InvertPermutation',
'Shape',
'DynamicShape',
'DropoutDoMask',
'DropoutGenMask',
'DropoutGrad',
......
......@@ -27,7 +27,7 @@ import numpy as np
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_rw as sig_rw
......@@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer):
out = {'shape': x_shape,
'dtype': x['dtype'],
'value': value}
if 'min_shape' in x and 'max_shape' in x:
out['min_shape'] = x['min_shape']
out['min_shape'].insert(axis_v, 1)
out['max_shape'] = x['max_shape']
out['max_shape'].insert(axis_v, 1)
return out
......@@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer):
out = {'shape': x['shape'],
'dtype': mstype.tensor_type(t['value']),
'value': value}
if 'min_shape' in x and 'max_shape' in x:
out['min_shape'] = x['min_shape']
out['max_shape'] = x['max_shape']
return out
......@@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer):
return out
class DynamicShape(Primitive):
"""
Returns the shape of the input tensor.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor[int], a 1-D tensor of type int32.
Examples:
>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> shape = P.DynamicShape()
>>> output = shape(input_tensor)
"""
@prim_attr_register
def __init__(self):
"""init Shape"""
class Squeeze(PrimitiveWithInfer):
"""
Returns a tensor with the same type but dimensions of 1 being removed based on axis.
......@@ -578,7 +607,7 @@ class Unique(Primitive):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
class GatherV2(PrimitiveWithInfer):
class GatherV2(PrimitiveWithCheck):
"""
Returns a slice of input tensor based on the specified indices and axis.
......@@ -605,7 +634,7 @@ class GatherV2(PrimitiveWithInfer):
"""init index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis):
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
......@@ -613,13 +642,6 @@ class GatherV2(PrimitiveWithInfer):
params_shp = params['shape']
rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
if axis_v < 0:
axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
return out
class SparseGatherV2(GatherV2):
......
......@@ -26,7 +26,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
......@@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer):
return None
class Sqrt(PrimitiveWithInfer):
class Sqrt(PrimitiveWithCheck):
"""
Returns square root of a tensor element-wise.
......@@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer):
"""init Sqrt"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
def check_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
def infer_value(self, x):
if x is not None:
......
......@@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
......@@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
return var_dtype, accum_dtype
class SparseApplyProximalAdagrad(PrimitiveWithInfer):
class SparseApplyProximalAdagrad(PrimitiveWithCheck):
r"""
Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
an additional index tensor is input.
......@@ -4433,11 +4433,10 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
outputs=['var', 'accum'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
......@@ -4446,7 +4445,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype, accum_dtype
class ApplyAddSign(PrimitiveWithInfer):
......@@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer):
return var_type
class SparseApplyFtrl(PrimitiveWithInfer):
class SparseApplyFtrl(PrimitiveWithCheck):
"""
Update relevant entries according to the FTRL-proximal scheme.
......@@ -5053,21 +5051,19 @@ class SparseApplyFtrl(PrimitiveWithInfer):
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape, linear_shape
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype
class SparseApplyFtrlV2(PrimitiveWithInfer):
......
......@@ -200,6 +200,84 @@ class Primitive(Primitive_):
return self._update_parameter
class PrimitiveWithCheck(Primitive):
"""
PrimitiveWithCheck is the base class of primitives in Python. It defines functions for checking the operator's
input arguments, but uses the infer method registered in the C++ source code.
Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape() and
check_dtype(). If __check__() is defined in the primitive, it has the highest priority to be called.
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
the shape and type.
Args:
name (str): Name of the current Primitive.
Examples:
>>> # init a Primitive class with check
>>> class Flatten(PrimitiveWithCheck):
>>> @prim_attr_register
>>> def __init__(self):
>>> pass
>>> def check_shape(self, input_x):
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
>>>
>>> def check_dtype(self, input_x):
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
>>>
>>> # init a Primitive obj
>>> add = Flatten()
"""
def __init__(self, name):
Primitive.__init__(self, name)
self.set_prim_type(prim_type.py_infer_check)
def _clone(self):
"""
Deeply clones the primitive object.
Calls the __init__() method with the same arguments. This method is called in the parser if the
flag self.__setattr_flag__ is True.
"""
cloned_prim = Primitive._clone(self)
return cloned_prim
def check_shape(self, *args):
"""
Check shapes of input args.
Note:
The shape of scalar is an empty tuple.
Args:
args (tuple(int)): shapes of input tensors.
Return:
None.
"""
return None
def check_dtype(self, *args):
"""
Check data types of input args.
Args:
args (:class:`mindspore.dtype`): data type of inputs.
Return:
None.
"""
return None
def __check__(self, *args):
"""Check shape, type, and value at the same time by using dictionary as arguments."""
tracks = ['dtype', 'shape']
for track in tracks:
fn = getattr(self, 'check_' + track)
fn(*(x[track] for x in args))
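Concretely, __check__ above receives one dict per input (the same 'shape'/'dtype'/'value' dicts built on the C++ side) and fans the tracked entries out to check_dtype and check_shape; an illustrative dispatch with made-up argument dicts:

args = ({'shape': [20, 12], 'dtype': 'float32', 'value': None},
        {'shape': [-1], 'dtype': 'int32', 'value': None})
for track in ('dtype', 'shape'):
    # e.g. a subclass's check_shape is called as check_shape([20, 12], [-1])
    print('check_' + track, [a[track] for a in args])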
class PrimitiveWithInfer(Primitive):
"""
PrimitiveWithInfer is the base class of primitives in Python; it defines functions for tracking inference in Python.
......@@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive):
if not is_graph_mode:
return out
# output does not contain dynamic shape, no need to calculate min/max shape
def has_dynamic_shape(shp):
if isinstance(shp, int):
return shp < 0
if isinstance(shp, (list, tuple)):
return any(has_dynamic_shape(e) for e in shp)
return False
if not has_dynamic_shape(out['shape']):
return out
# calculate min/max shape for output
def get_specified_shape(elems, attr):
has_specified_shape = False
ret_vals = []
......@@ -345,6 +435,8 @@ def prim_attr_register(fn):
def deco(self, *args, **kwargs):
if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
elif isinstance(self, PrimitiveWithCheck):
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
else:
Primitive.__init__(self, self.__class__.__name__)
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
......
......@@ -27,7 +27,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore import Tensor, RowTensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
......@@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis):
perm = index[1:1 + axis] + (0,) + index[1 + axis:]
return perm
class MySparseGatherV2(P.GatherV2):
# pylint: disable=W0231
class MySparseGatherV2(PrimitiveWithInfer):
"""
For test
"""
@prim_attr_register
def __init__(self):
"""init index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']
rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
if axis_v < 0:
axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
return out
@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test dynamic shape """
from mindspore import Tensor, context, nn, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype
import numpy as np
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
def test_sparse_apply_proximal_ada_grad():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
self.lr = 0.01
self.l1 = 0.0
self.l2 = 0.0
def construct(self, grad, indices):
out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices)
return out[0]
class NetWrapper(nn.Cell):
def __init__(self):
super(NetWrapper, self).__init__()
self.unq = P.Unique()
self.add = P.TensorAdd()
self.expand_dims = P.ExpandDims()
self.cast = P.Cast()
self.net = Net()
def construct(self, grad, inp):
ids, _ = self.unq(inp)
new_grad = self.expand_dims(ids, 1)
new_grad = self.cast(new_grad, mstype.float32) + grad
return self.net(new_grad, ids)
net = NetWrapper()
grad = Tensor(np.random.rand(1, 80).astype(np.float32))
indices = Tensor(np.ones([7800]), mstype.int32)
net(grad, indices)
def test_sparse_apply_ftrl():
class SparseApplyFtrlNet(nn.Cell):
def __init__(self):
super(SparseApplyFtrlNet, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
return out[0]
class NetWrapper(nn.Cell):
def __init__(self):
super(NetWrapper, self).__init__()
self.unq = P.Unique()
self.add = P.TensorAdd()
self.expand_dims = P.ExpandDims()
self.cast = P.Cast()
self.net = SparseApplyFtrlNet()
def construct(self, grad, inp):
ids, _ = self.unq(inp)
new_grad = self.expand_dims(ids, 1)
new_grad = self.cast(new_grad, mstype.float32) + grad
return self.net(new_grad, ids)
net = NetWrapper()
grad = Tensor(np.random.rand(1, 80).astype(np.float32))
indices = Tensor(np.ones([7800]), mstype.int32)
net(grad, indices)
def test_gatherv2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.unq = P.Unique()
self.gather = P.GatherV2()
def construct(self, x, y):
u, _ = self.unq(y)
z = self.gather(x, u, 0)
return z
x = Tensor(np.ones([20, 12], dtype=np.float32))
y = Tensor(np.ones([8], dtype=np.int32))
net = Net()
net(x, y)