提交 1d6c76f3 编写于 作者: W Wei Luning

board tensor for pynative infer

上级 c1c30a44
......@@ -285,12 +285,12 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
const abstract::AbstractBasePtrList &args_spec_list) {
MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
prim->BeginRecordAddAttr();
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
prim->EndRecordAddAttr();
op_exec_info->abstract = infer_res;
MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString();
MS_LOG(DEBUG) << "prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
......@@ -632,7 +632,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
auto obj = op_exec_info->op_inputs[i];
bool op_mask = py::hasattr(obj, "__parameter__");
(*op_masks).push_back(op_mask);
MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_;
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
<< grad_flag_;
AnfNodePtr node = nullptr;
abstract::AbstractBasePtr abs = nullptr;
......@@ -646,11 +647,17 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (node != nullptr && node->abstract() != nullptr) {
abs = node->abstract();
}
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
<< prim->is_const_value();
if (abs == nullptr || prim->is_const_value()) {
MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
ValuePtr input_value = PyAttrValue(obj);
bool broaden = !prim->is_const_value() && input_value->isa<tensor::Tensor>();
abs = abstract::FromValueInside(input_value, broaden);
abs = input_value->ToAbstract();
if (!prim->is_const_value()) {
auto config = abstract::AbstractBase::kBroadenTensorOnly;
abs = abs->Broaden(config);
MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config;
}
node_abs_map_[id] = abs;
}
(*args_spec_list).push_back(abs);
......
......@@ -66,9 +66,12 @@ ValuePtr AbstractBase::BuildValue() const {
return value_;
}
AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr AbstractBase::Broaden(uint8_t config) const {
AbstractBasePtr clone = Clone();
clone->set_value(kAnyValue);
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
clone->set_value(kAnyValue);
}
return clone;
}
......@@ -85,7 +88,7 @@ std::string AbstractBase::ToString() const {
return buffer.str();
}
AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); }
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); }
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
......@@ -224,11 +227,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const {
return ele_list;
}
AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
AbstractBasePtrList AbstractSequeue::ElementsBroaden(uint8_t config) const {
AbstractBasePtrList ele_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr broadend = ele->Broaden();
AbstractBasePtr broadend = ele->Broaden(config);
ele_list.push_back(broadend);
}
return ele_list;
......@@ -376,13 +379,13 @@ AbstractBasePtr AbstractSlice::Clone() const {
return std::make_shared<AbstractSlice>(start, stop, step);
}
AbstractBasePtr AbstractSlice::Broaden() const {
AbstractBasePtr AbstractSlice::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
AbstractBasePtr start = start_->Broaden();
AbstractBasePtr stop = stop_->Broaden();
AbstractBasePtr step = step_->Broaden();
AbstractBasePtr start = start_->Broaden(config);
AbstractBasePtr stop = stop_->Broaden(config);
AbstractBasePtr step = step_->Broaden(config);
return std::make_shared<AbstractSlice>(start, stop, step);
}
......@@ -506,12 +509,15 @@ AbstractBasePtr AbstractTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractTensor::Broaden() const {
AbstractBasePtr AbstractTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element_);
auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
auto not_broaden = config & kBroadenParameterOnly;
if (not_broaden == 0) {
broaden->set_value(kAnyValue);
}
return broaden;
}
......@@ -585,12 +591,12 @@ AbstractBasePtr AbstractDictionary::Clone() const {
return std::make_shared<AbstractDictionary>(kv);
}
AbstractBasePtr AbstractDictionary::Broaden() const {
AbstractBasePtr AbstractDictionary::Broaden(uint8_t config) const {
std::vector<AbstractAttribute> kv;
(void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[config](const AbstractAttribute &item) {
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden());
return std::make_pair(item.first, item.second->Broaden(config));
});
return std::make_shared<AbstractDictionary>(kv);
}
......@@ -711,11 +717,11 @@ AbstractBasePtr AbstractClass::Clone() const {
return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
}
AbstractBasePtr AbstractClass::Broaden() const {
AbstractBasePtr AbstractClass::Broaden(uint8_t config) const {
std::vector<AbstractAttribute> attributes_clone;
for (auto attr : attributes_) {
MS_EXCEPTION_IF_NULL(attr.second);
AbstractBasePtr clone = attr.second->Broaden();
AbstractBasePtr clone = attr.second->Broaden(config);
AbstractAttribute elem(attr.first, clone);
attributes_clone.push_back(elem);
}
......@@ -843,9 +849,8 @@ TypePtr AbstractRef::BuildType() const {
}
bool AbstractRef::operator==(const AbstractRef &other) const {
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
(!need_cast_ || (*target_type_ == *other.target_type_));
// not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
}
bool AbstractRef::operator==(const AbstractBase &other) const {
......@@ -921,9 +926,12 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
AbstractBasePtr AbstractRefKey::Broaden() const {
AbstractBasePtr AbstractRefKey::Broaden(uint8_t config) const {
auto refkey = std::make_shared<AbstractRefKey>();
refkey->set_value(kAnyValue);
auto not_broaden = config & (kBroadenTensorOnly | kBroadenParameterOnly);
if (not_broaden == 0) {
refkey->set_value(kAnyValue);
}
return refkey;
}
......@@ -1016,9 +1024,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const {
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
}
AbstractBasePtr AbstractKeywordArg::Broaden() const {
AbstractBasePtr AbstractKeywordArg::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(arg_value_);
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden(config));
}
std::size_t AbstractKeywordArg::hash() const {
......@@ -1123,7 +1131,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractRowTensor::Broaden() const {
AbstractBasePtr AbstractRowTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
auto shp = shape();
......@@ -1182,7 +1190,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const {
return clone;
}
AbstractBasePtr AbstractSparseTensor::Broaden() const {
AbstractBasePtr AbstractSparseTensor::Broaden(uint8_t config) const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape();
......
......@@ -69,7 +69,14 @@ class AbstractBase : public Base {
virtual TypePtr BuildType() const = 0;
virtual BaseShapePtr BuildShape() const { return kNoShape; }
virtual AbstractBasePtr Clone() const = 0;
virtual AbstractBasePtr Broaden() const;
// mask for Broaden config
inline static const uint8_t kBroadenTensorOnly = 1;
inline static const uint8_t kBroadenParameterOnly = 2;
// Each bit for on config.
// 00000001 -> 1: only boarden tensor
// 00000010 -> 2: only boarden parameter
virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) {
......@@ -108,7 +115,7 @@ class AbstractScalar : public AbstractBase {
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone());
}
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr Join(const AbstractBasePtr &other) override;
};
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>;
......@@ -128,7 +135,7 @@ class AbstractType : public AbstractBase {
TypePtr BuildType() const override { return std::make_shared<TypeType>(); }
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override { return Clone(); }
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
};
using AbstractTypePtr = std::shared_ptr<AbstractType>;
......@@ -143,7 +150,7 @@ class AbstractError : public AbstractBase {
MS_DECLARE_PARENT(AbstractError, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<Problem>(); }
AbstractBasePtr Broaden() const override { return Clone(); }
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); }
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_);
......@@ -180,7 +187,7 @@ class AbstractFunction : public AbstractBase {
TypePtr BuildType() const override { return std::make_shared<Function>(); }
AbstractBasePtr Clone() const override { return Copy(); }
// For Function, no need to broaden.
AbstractBasePtr Broaden() const override {
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>();
}
virtual AbstractFunctionPtr Copy() const = 0;
......@@ -209,7 +216,7 @@ class AbstractKeywordArg : public AbstractBase {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::size_t hash() const override;
bool operator==(const AbstractKeywordArg &other) const;
......@@ -275,7 +282,7 @@ class AbstractTensor : public AbstractUndetermined {
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other) final;
int format() const { return this->format_; }
......@@ -312,7 +319,7 @@ class AbstractSequeue : public AbstractBase {
TypePtrList ElementsType() const;
BaseShapePtrList ElementsShape() const;
AbstractBasePtrList ElementsClone() const;
AbstractBasePtrList ElementsBroaden() const;
AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const;
template <typename T>
ValuePtr ElementsBuildValue() const;
......@@ -345,7 +352,9 @@ class AbstractTuple : public AbstractSequeue {
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); }
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractTuple>(ElementsBroaden(config));
}
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); }
......@@ -372,7 +381,9 @@ class AbstractList : public AbstractSequeue {
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); }
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractList>(ElementsBroaden(config));
}
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); }
......@@ -403,7 +414,7 @@ class AbstractClass : public AbstractBase {
AbstractBasePtr GetAttribute(const std::string &name);
ValuePtr GetMethod(const std::string &name);
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override;
Named tag() const { return tag_; }
std::size_t hash() const override;
......@@ -428,7 +439,7 @@ class AbstractDictionary : public AbstractBase {
bool operator==(const AbstractDictionary &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override;
std::size_t hash() const override;
std::size_t size() const { return key_values_.size(); }
......@@ -452,7 +463,7 @@ class AbstractSlice : public AbstractBase {
bool operator==(const AbstractSlice &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override;
std::size_t hash() const override;
AbstractBasePtr start() const { return start_; }
......@@ -478,7 +489,9 @@ class AbstractJTagged : public AbstractBase {
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); }
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
return std::make_shared<AbstractJTagged>(element_->Broaden(config));
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractJTagged &other) const;
......@@ -558,7 +571,7 @@ class AbstractRefKey : public AbstractBase {
}
RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
std::string ToString() const override;
private:
......@@ -588,8 +601,9 @@ class AbstractRef : public AbstractBase {
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline TypePtr target_type() const { return target_type_; }
inline bool need_cast() const { return need_cast_; }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
AbstractBasePtr Broaden(uint8_t config = 0) const override {
// always broaden for ref
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override {
......@@ -636,7 +650,7 @@ class AbstractRowTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
......@@ -665,7 +679,7 @@ class AbstractSparseTensor : public AbstractUndetermined {
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册