提交 5b9899eb 编写于 作者: Z Zichun Ye

add support to resolve attr for a class

update logic for cache

clean up code

add test cases for new api

update include of code

specify the resolve attr to ClassMember

avoid resolve attr of self
上级 d8fcf269
...@@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_attr_ =
MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
} }
......
...@@ -118,6 +118,7 @@ class ResolveIRPassLib { ...@@ -118,6 +118,7 @@ class ResolveIRPassLib {
ResolveIRPassLib(); ResolveIRPassLib();
~ResolveIRPassLib() = default; ~ResolveIRPassLib() = default;
SubstitutionPtr resolver_resolve_attr_;
SubstitutionPtr resolver_resolve_; SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_; SubstitutionPtr resolver_getattr_;
}; };
......
...@@ -21,15 +21,21 @@ ...@@ -21,15 +21,21 @@
#include <memory> #include <memory>
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "ir/pattern_matcher.h"
#include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/parse_base.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
const char PARSE_SUPER_NAME[] = "namespace";
// {prim::kPrimResolve, Ns, Sym} // {prim::kPrimResolve, Ns, Sym}
class ResolverResolve : public AnfVisitor { class ResolverResolve : public AnfVisitor {
public: public:
...@@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor { ...@@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor {
parse::NameSpacePtr ns_{nullptr}; parse::NameSpacePtr ns_{nullptr};
parse::SymbolPtr sym_{nullptr}; parse::SymbolPtr sym_{nullptr};
}; };
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
class ResolveAttr : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto node_to_getattr = node->cast<CNodePtr>()->input(1);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value();
auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) {
// deal with the case of getting attr from a class member
// and avoid the case of getting attr from self (the result of ParseSuper)
auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string);
return result;
}
return nullptr;
};
MATCH_REPLACE_LAMBDA_IF(
node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node),
ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node));
return nullptr;
}
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func ...@@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func
return true; return true;
} }
} // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, // resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager
AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &node) { const AnfNodePtr &node) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
ScopeGuard scope_guard(node->scope()); ScopeGuard scope_guard(node->scope());
AnfNodePtr resolved_node = nullptr; AnfNodePtr resolved_node = nullptr;
TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
...@@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr ...@@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
TraceManager::EndTrace(); TraceManager::EndTrace();
return resolved_node; return resolved_node;
} }
} // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
if (!data_converter::IsCellInstance(obj)) {
return nullptr;
}
py::object obj_attr = obj.attr(attr.c_str());
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
return resolved_node;
}
namespace { namespace {
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
opt::OptPassGroupMap map({ opt::OptPassGroupMap map({
{"resolve_attr",
{
// for resolve primitive;
irpass.resolver_resolve_attr_,
}},
{"resolve", {"resolve",
{ {
// for resolve and getattr primitive; // for resolve and getattr primitive;
......
...@@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; ...@@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node); const AnfNodePtr &node);
// Resolve Cell with attr name.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import numpy as np
from scipy import stats
import mindspore.nn as nn
from mindspore import dtype
from mindspore import Tensor
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: new api of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = nn.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
prob = self.normal.prob('prob', kl)
return prob
def test_new_api():
"""
Test new api of normal distribution.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册