提交 ef8cdbd8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2639 [ad] improve the performance of ad

Merge pull request !2639 from Kang/master
......@@ -424,7 +424,6 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
}
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
if (k_prim != nullptr) {
k_prim = BasicClone(k_prim);
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
......
......@@ -47,12 +47,12 @@ struct PrimitiveTotalEqual {
return false;
}
for (auto &attr : attrs1) {
if (!t2->HasAttr(attr.first)) {
for (auto &attr1 : attrs1) {
if (!t2->HasAttr(attr1.first)) {
return false;
}
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) {
return false;
}
}
......@@ -61,7 +61,7 @@ struct PrimitiveTotalEqual {
}
};
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher>;
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
extern KPrim g_k_prims;
class DFunctor;
......
......@@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
}
auto prim = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
return iter->second;
}
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
auto fprop = GetFprop(prim);
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
return fprop;
}
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr;
}
bool is_faked_bprop = false;
FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}
if (bprop_fg == nullptr) {
bool is_faked_bprop = false;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
}
}
......@@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
<< trace::GetDebugInfo(bprop_fg->debug_info());
}
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = expanded_fg;
}
return expanded_fg;
}
......
......@@ -38,7 +38,10 @@ TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
ASSERT_TRUE(fg == fg1);
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
ASSERT_TRUE(Isomorphic(fg, fg1, &equiv_graph, &equiv_node));
}
} // namespace prim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册