提交 ea3ddea3 编写于 作者: W Wei Luning

remove ref origin

上级 e7df5416
......@@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
}
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("hyper_map");
FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptr_graph->debug_info()->set_name("hyper_map");
AnfNodePtr ptrFnArg = nullptr;
std::size_t i = 0;
ArgsPairList argmap;
ArgsPairList argmap2;
if (fn_leaf_ == nullptr) {
ptrFnArg = ptrGraph->add_parameter();
ptrFnArg = ptr_graph->add_parameter();
i = 1;
}
std::size_t size = args_spec_list.size();
for (; i < size; ++i) {
argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
}
argmap2 = Harmonize(ptrGraph, argmap);
ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
return ptrGraph;
argmap2 = Harmonize(ptr_graph, argmap);
ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
return ptr_graph;
}
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
......@@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
inputs.push_back(opsTupleItem);
inputs.push_back(cnode);
inputs.push_back(NewValueNode(1));
AnfNodePtr ptrBprop = ret->NewCNode(inputs);
AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
return ret;
}
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem) {
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr ptrBPropArg = nullptr;
AnfNodePtr ptr_bprop_arg = nullptr;
if (sens_param_) {
ptrBPropArg = func_graph->add_parameter();
ptr_bprop_arg = func_graph->add_parameter();
} else {
auto ones_like = prim::GetPythonOps("ones_like");
ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
}
AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
CNodePtr fv_bprop = nullptr;
if (get_by_list_) {
// python code: grads = hyper_map(F.partial(env_get, env), weights)
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
AnfNodePtr partial_env_get =
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
......@@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
CNodePtr inputs_bprop = nullptr;
if (get_all_) {
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
}
// Gradients wrt inputs and parameters
......@@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
}
// Gradients wrt first input.
// ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
}
// Generate the graph.
......@@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr ptrGraph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(ptrGraph);
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
FuncGraphPtr ptr_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(ptr_graph);
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
auto nparam = ptrGraph->parameters().size();
auto nparam = ptr_graph->parameters().size();
std::ostringstream ss;
ss << "grad{" << nparam << "}";
dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
dfBuilder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = dfBuilder->add_parameter();
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
df_builder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = df_builder->add_parameter();
AnfNodePtr weights = nullptr;
if (get_by_list_) {
weights = dfBuilder->add_parameter();
weights = df_builder->add_parameter();
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimJ));
inputs.push_back(param_graph);
auto jf = dfBuilder->NewCNode(inputs);
auto jf = df_builder->NewCNode(inputs);
// df is checked in GetGrad
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
auto df = GetGrad(jf, weights, ptrGraph->parameters());
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
auto df = GetGrad(jf, weights, ptr_graph->parameters());
TraceManager::EndTrace();
dfBuilder->set_output(NewValueNode(df));
df_builder->set_output(NewValueNode(df));
return dfBuilder;
return df_builder;
}
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
......
......@@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
TypeId *arg_type = nullptr) {
if (arg_value->isa<abstract::AbstractRef>()) {
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
auto ref = arg_value->cast<abstract::AbstractRefPtr>();
arg_value = ref->ref();
if (!is_write && ref->need_cast()) {
auto tensor_type = ref->target_type();
*arg_type_id = tensor_type->type_id();
if (arg_type != nullptr) {
*arg_type = kObjectTypeTensorType;
}
return true;
}
}
if (arg_value->isa<abstract::AbstractTensor>()) {
......@@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
continue;
}
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
<< " to " << it->second;
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
}
}
......@@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
TypePtr type = args_spec_list[i]->GetTypeTrack();
if (type && type->type_id() == kObjectTypeRef) {
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
if (sig == SignatureEnumRW::kRWRead) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
if (ref_abs && ref_abs->need_cast()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
}
} else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
write_indices.insert(i);
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
}
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
<< args_spec_list[i]->ToString();
op_inputs.push_back(param);
}
// process default
......
......@@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
}
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
// No need to check, check will be done in infer.
auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret_graph->debug_info()->set_name("UnpackCall");
AnfNodePtr fnNode = ret_graph->add_parameter();
AnfNodePtr fn_node = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(fnNode);
elems.push_back(fn_node);
for (size_t index = 1; index < arg_length; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (args_spec_list[index]->isa<AbstractTuple>()) {
......
......@@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
// arguments: key, value, original value
// arguments: key, value, target type(None if no target type)
if (args_spec_list.size() != 3) {
MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
<< ".";
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
}
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
auto need_cast = !tensor_target_v->isa<None>();
if (need_cast && !tensor_target_v->isa<Type>()) {
MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
}
TypePtr cast_target = tensor_target_v->cast<TypePtr>();
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
}
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
......@@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
if (type->type_id() != kObjectTypeRef) {
MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
return args_spec_list[0];
}
return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
}
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
// arguments: value
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
<< ".";
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
if (type->type_id() != kObjectTypeRef) {
MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
}
return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
}
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: Two objects of a subclass of AbstractBase, key and value.
......
......@@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_ref_param_eliminate_ =
MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
{prim::kPrimGetRefKey, prim::kPrimGetRefValue});
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
......
......@@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller {
};
// {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class GetRefParamEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
return nullptr;
}
};
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
return nullptr;
}
......
......@@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
return func_graph;
}
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
TypePtr dst_type;
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
return kFloat32;
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
return kFloat16;
} else {
return kNone;
}
}
// if any mixed precision flag add a cast node after the parameter node.
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
TypePtr dst_type;
......
......@@ -359,6 +359,7 @@ class ParseAst {
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
} // namespace parse
} // namespace mindspore
......
......@@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() {
}
namespace {
// if any mixed precision flag add a cast node after the parameter node.
// argument obj should be python Parameter object
// it will be converted to Parameter node here
AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
......@@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
}
auto iter = func_graph->make_ref_params().find(para_node);
if (iter == func_graph->make_ref_params().end()) {
AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node);
ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node});
AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
func_graph->make_ref_params()[para_node] = ref_node;
func_graph->add_parameter_obj_node(ref_node);
return ref_node;
......
......@@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
......
......@@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
free_param->debug_info()->set_name(param_name);
para_node = free_param;
}
AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
AnfNodePtr ref_key_node = NewValueNode(refkey);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
w_args.push_back(ref_node);
}
} else {
......
......@@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const {
return buffer.str();
}
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
TypePtr cast_target)
: ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
set_type(std::make_shared<RefType>());
auto origin_type = ref_value->BuildType();
if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
if (cast_target != tensor_dtype) {
need_cast_ = true;
target_type_ = cast_target;
}
}
}
if (ref_key && ref_key->isa<AbstractRefKey>()) {
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
}
}
BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
TypePtr AbstractRef::BuildType() const {
TypePtr subtype = ref_->BuildType();
TypePtr subtype_origin = ref_origin_->BuildType();
TypePtr subtype_origin = subtype;
if (need_cast_) {
subtype_origin = std::make_shared<TensorType>(target_type_);
}
return std::make_shared<RefType>(subtype, subtype_origin);
}
bool AbstractRef::operator==(const AbstractRef &other) const {
return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_);
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
(!need_cast_ || (*target_type_ == *other.target_type_));
// not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
}
bool AbstractRef::operator==(const AbstractBase &other) const {
......@@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
return false;
}
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
return ret;
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
return ret;
}
auto ret = std::make_shared<AbstractRefKey>();
ret->set_value(res_value);
return ret;
}
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
return std::make_shared<AbstractRef>(ref_key_, new_ref);
}
if (*this == *other) {
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
return shared_from_base<AbstractBase>();
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = ref_->Join(other_ref->ref());
auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
return std::make_shared<AbstractRef>(ref_key, ref);
}
std::string AbstractRef::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
<< " origin_value: " << ref_origin_->ToString();
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
if (need_cast_) {
buffer << " cast to: " << target_type_->ToString();
}
auto value = GetValueTrack();
if (value) {
buffer << ", value: " << value->ToString();
......@@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
AbstractBasePtr AbstractRefKey::Broaden() const {
auto refkey = std::make_shared<AbstractRefKey>();
refkey->set_value(kAnyValue);
return refkey;
}
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
ValuePtr value_self = GetValueTrack();
ValuePtr value_other = other.GetValueTrack();
......
......@@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
class AbstractRefKey : public AbstractBase {
public:
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
~AbstractRefKey() override = default;
MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
bool operator==(const AbstractRefKey &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); }
AbstractBasePtr Clone() const override {
auto cloned = std::make_shared<AbstractRefKey>();
cloned->set_value(GetValueTrack());
return cloned;
}
inline void set_value(const ValuePtr &value) {
AbstractBase::set_value(value);
ref_key_value_ = value->cast<RefKeyPtr>();
}
RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override;
private:
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_{nullptr};
};
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
class AbstractRef : public AbstractBase {
public:
AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin)
: ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) {
set_type(std::make_shared<RefType>());
}
AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
TypePtr cast_target = nullptr);
~AbstractRef() override = default;
MS_DECLARE_PARENT(AbstractRef, AbstractBase)
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
bool operator==(const AbstractRef &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone());
return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
}
std::string ToString() const override;
AbstractBasePtr ref() { return ref_; }
AbstractBasePtr ref_origin() { return ref_origin_; }
AbstractBasePtr ref_key() { return ref_key_; }
inline AbstractBasePtr ref() const { return ref_; }
inline AbstractBasePtr ref_key() const { return ref_key_; }
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline TypePtr target_type() const { return target_type_; }
inline bool need_cast() const { return need_cast_; }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override {
return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
}
private:
AbstractBasePtr ref_key_;
AbstractBasePtr ref_;
AbstractBasePtr ref_origin_;
// For mix presicion, only float type need to cast to float16 of float32
bool need_cast_;
TypePtr target_type_;
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_;
};
using AbstractRefPtr = std::shared_ptr<AbstractRef>;
......
......@@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const {
}
if (arg->isa<AbstractRef>()) {
MS_LOG(DEBUG) << "refkey broaden";
auto arg_spec = dyn_cast<AbstractRef>(arg);
auto ret_spec = arg_spec->Broaden();
return ret_spec;
return arg->Broaden();
}
return arg;
});
......
......@@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册