提交 2a6d346d 编写于 作者: H huangdongrun

support if by if grad parameter

add join for ref

adjust env eliminate to eliminate all env ops

add partial app cache

resolve while endless

fix env eliminate

support for "for while" cases

fix join shape error
上级 eb474964
......@@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra
}
oss << "SymInst(%para" << idx << ")";
} else {
MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
oss << "SymInst(cnode_" << sym_node->ToString() << ")";
}
return oss.str();
......
......@@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
if (!morph->isa<CNode>()) {
return nullptr;
}
// for free variable, which may be handled in MapValueObject, just return it
auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
if (node_adjoint_found != anfnode_to_adjoin_.end()) {
return node_adjoint_found->second;
}
ScopeGuard scope_guard(morph->scope());
auto cnode_morph = morph->cast<CNodePtr>();
......@@ -502,7 +507,7 @@ void DFunctor::MapFvObject() {
if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
} else {
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
if (is_top_ || node->isa<Parameter>()) {
// Out of ad scope, add adjoint for free variables.
adjoint = std::make_shared<Adjoint>(node, node, tape_);
UpdateAdjoint(adjoint);
......
......@@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ =
......@@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// inline
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
......
......@@ -55,6 +55,7 @@ class OptimizeIRPassLib {
SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_;
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;
// Ref eliminate
......@@ -80,6 +81,7 @@ class OptimizeIRPassLib {
// inline
SubstitutionPtr inline_;
SubstitutionPtr inline_without_move_;
SubstitutionPtr replace_applicator_;
SubstitutionPtr specialize_transform_;
......@@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
auto inp0 = node->cast<CNodePtr>()->input(0);
return (inp0 != nullptr) && inp0->isa<CNode>();
}
// check if the cnode is a switch cnode
inline bool IsCNodeSwitch(const AnfNodePtr &node) {
if (node != nullptr) {
if (node->isa<CNode>()) {
return IsPrimitiveCNode(node, prim::kPrimSwitch);
}
}
return false;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -29,6 +29,7 @@
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/symbolic.h"
......@@ -59,8 +60,13 @@ class EnvGetitemTransform {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
......@@ -91,33 +97,12 @@ class EnvGetitemTransform {
class NewEnvGetItem : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
auto gety = [this](const AnfNodePtr &node) -> bool {
this->y_ = node;
return true;
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
if (env_ != nullptr && env_->Len() == 0) {
return y_;
}
PatternNode c1, c2, y;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
(IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) &&
(GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0));
return nullptr;
}
void Visit(const ValueNodePtr &vnode) override {
if (env_ == nullptr) {
env_ = GetValueNode<EnvInstancePtr>(vnode);
}
}
void Reset() {
y_ = nullptr;
env_ = nullptr;
}
private:
AnfNodePtr y_{nullptr};
EnvInstancePtr env_{nullptr};
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
......@@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
......@@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller {
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
public:
IncorporateEnvGetitem() : env_get_item_transform_() {}
explicit IncorporateEnvGetitem(bool bypass_recursive = false)
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
~IncorporateEnvGetitem() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
......@@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
auto inputs = inp1->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
auto new_fg = env_get_item_transform_(fg, key, default_v);
if (fg->recursive() && bypass_recursive_) {
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
return nullptr;
}
if (new_fg == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(new_fg));
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
......@@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor {
private:
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
bool bypass_recursive_;
};
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
......@@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
auto new_g1 = env_get_item_transform_(g1, key, default_v);
auto new_g2 = env_get_item_transform_(g2, key, default_v);
if (new_g1 == nullptr || new_g2 == nullptr) {
return nullptr;
}
auto fg = node->func_graph();
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});
......
......@@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
bool unique_use = IsUniqueUse(fg, nullptr);
bool is_recursive = fg->recursive();
if (fg->parent() != nullptr && is_recursive) {
if (fg->parent() == node->func_graph() && unique_use) {
return true;
}
}
return false;
}
// {G, Xs}
class InlinerBase : public AnfVisitor {
public:
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
: use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>()) {
......@@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
// Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined.
......@@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor {
std::vector<AnfNodePtr> params;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
if (IsUniqueUse(fg, nullptr)) {
// compare size to avoid the case that the function has default value after grad.
// for which after renormalize, the function default value will be an input
if (fg->parameters().size() != params.size()) {
return nullptr;
}
if (use_move_ && IsUniqueUse(fg, nullptr)) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, params, fg);
......@@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
private:
bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
};
class Inliner : public InlinerBase {
public:
Inliner()
: InlinerBase({
{IsUniqueUse, true},
{IsTrivial, false},
{IsInside, false},
{IsCore, false},
{NoCriterion, true},
}) {}
explicit Inliner(bool use_move = true)
: InlinerBase(
{
{IsUniqueUse, true},
{IsTrivial, false},
{IsInside, false},
{IsCore, false},
{IsDirectParentCall, false},
{NoCriterion, true},
},
use_move) {}
~Inliner() override = default;
};
class DirectInliner : public InlinerBase {
public:
explicit DirectInliner(bool use_move = true)
: InlinerBase(
{
{IsDirectParentCall, false},
},
use_move) {}
~DirectInliner() override = default;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -26,6 +26,30 @@
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
class GetRefValueTransform {
public:
GetRefValueTransform() {}
~GetRefValueTransform() = default;
AnfNodePtr operator()(const AnfNodePtr &node) {
CNodePtr cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
if (fg->recursive()) {
MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString();
return node;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
auto output = new_fg->output();
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
inputs[0] = NewValueNode(new_fg);
auto ret_node = cnode->func_graph()->NewCNode(inputs);
return ret_node;
}
};
} // namespace internal
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public OptimizerCaller {
public:
......@@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
class GetMakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
internal::GetRefValueTransform trans;
auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
auto rep = trans(x.GetNode(node));
if (rep != nullptr) {
return rep;
}
return nullptr;
};
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
return nullptr;
}
};
......
......@@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
auto context_ptr = MsContext::GetInstance();
std::string backend = MsContext::GetInstance()->backend_policy();
MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false);
......@@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
context_ptr->set_loop_sink_flag(false);
} else if (context_ptr->execution_mode() != kPynativeMode) {
std::string device_target = context_ptr->device_target();
if (device_target == kAscendDevice) {
if (device_target == kAscendDevice && backend != kMsVm) {
bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
}
}
if (IsCtrlSink()) {
if (IsCtrlSink() && backend == kMsConvert) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
......@@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error";
}
if (IsCtrlSink()) {
std::string backend = MsContext::GetInstance()->backend_policy();
if (IsCtrlSink() && backend == kMsConvert) {
if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
......
......@@ -30,6 +30,7 @@
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/validator.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse.h"
#include "frontend/optimizer/graph_kernel_reuse.h"
......@@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
});
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_,
irpass.same_eliminate_,
......@@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"virtual_dataset", virtual_dataset},
{"grad", grad},
{"resolve", resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSE(false))},
{"a_3", a_3}});
......@@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return map_a;
}
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining
irpass.inline_,
irpass.partial_eliminate_,
});
OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
return map_a;
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 =
opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,
......@@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) {
opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
......@@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
......@@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) {
return true;
}
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(res->manager());
if (res->manager()->func_graphs().size() <= 1) {
return true;
}
return MergeDuplicateGraphs(res->manager());
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
}
auto manager = res->manager();
HashCache hash_cache;
HashValue hashes;
// Remove duplicated value nodes across all graphs in manager
for (auto &fg : manager->func_graphs()) {
auto value_nodes = fg->value_nodes();
for (const auto &value_pair : value_nodes) {
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
}
}
return true;
}
bool CconvPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
......@@ -340,6 +388,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
......
......@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "pipeline/jit/remove_value_node_dup.h"
#include "ir/anf.h"
......@@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
// Meet for the first time, append node to bucket.
bucket.emplace_back(node);
}
size_t HashOfGraph(const FuncGraphPtr &fg) {
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
MS_LOG(DEBUG) << "TopSort for:" << fg->ToString();
std::unordered_map<AnfNodePtr, std::size_t> hashes;
auto &params = fg->parameters();
for (size_t i = 0; i < params.size(); i++) {
hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
}
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
if (hashes.find(node) != hashes.end()) {
continue;
}
std::size_t h = 0;
if (node->isa<ValueNode>()) {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (IsValueNode<FuncGraph>(value_node)) {
auto v_fg = value->cast<FuncGraphPtr>();
h = value->hash();
} else if (IsValueNode<tensor::Tensor>(value_node)) {
// the tensor has same value has been replaced in duplicate value pass,
// so we use the value pointer here as an identifier
h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
} else {
h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
}
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
size_t init = 0;
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
return hash_combine(hash, hashes[node_in]);
});
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknow node type";
}
hashes[node] = h;
}
return hashes[fg->get_return()];
}
bool IsCNodeGraph(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
return IsValueNode<FuncGraph>(inp0);
}
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
std::unordered_map<FuncGraphPtr, size_t> graph_hash;
for (auto fg : manager->func_graphs()) {
size_t h = HashOfGraph(fg);
graph_hash[fg] = h;
if (hash_graphs.find(h) == hash_graphs.end()) {
hash_graphs[h] = {fg};
} else {
hash_graphs[h].push_back(fg);
}
}
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
for (auto &fg : manager->func_graphs()) {
MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
for (auto &item : fg->nodes()) {
if (!item->isa<CNode>()) {
continue;
}
auto &inputs = item->cast<CNodePtr>()->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
if (!inputs[i]->isa<ValueNode>()) {
continue;
}
auto value_ptr = GetValueNode(inputs[i]);
auto v_fg = value_ptr->cast<FuncGraphPtr>();
if (v_fg == nullptr) {
continue;
}
auto &fg_vec = hash_graphs[graph_hash[v_fg]];
if (fg_vec.size() > 1) {
if (v_fg != fg_vec[0]) {
bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
if (is_morphic) {
auto new_node = NewValueNode(fg_vec[0]);
MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
manager->Replace(inputs[i], new_node);
}
}
}
}
}
}
return true;
}
} // namespace pipeline
} // namespace mindspore
......@@ -28,6 +28,10 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
size_t HashOfGraph(const FuncGraphPtr &fg);
bool IsCNodeGraph(const AnfNodePtr &node);
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
} // namespace pipeline
} // namespace mindspore
......
......@@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
AbstractBasePtr ret_base = nullptr;
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << ret_base->ToString();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
}
MS_EXCEPTION_IF_NULL(ret_base);
......@@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->Broaden();
if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;
});
if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
<< " does not equal to number of original buffer arguments "
<< func_graph_->joined_shapes_.size();
}
for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
}
}
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
<< ", broaded: " << mindspore::ToString(broaded_list);
return broaded_list;
......@@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
return joined_args_spec_list;
......@@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
......
......@@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
trace::TraceEvalCNodeLeave();
} else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
<< (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
......@@ -301,6 +302,8 @@ void AnalysisEngine::Clear() {
anfnode_config_map_.clear();
eval_trace_.clear();
constructors_.clear();
constructors_app_.clear();
continued_evals_.clear();
}
namespace {
......@@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto part_pair = std::make_pair(func_orig, func->args());
auto itr = constructors_app_.find(part_pair);
if (itr != constructors_app_.end()) {
return itr->second;
}
std::shared_ptr<PartialAppEvaluator> partial_evaluator =
std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
constructors_app_[part_pair] = partial_evaluator;
return partial_evaluator;
}
......@@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
if (fg_eval == nullptr) {
return;
}
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
auto undetermined_fgs = fg->recursive();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
......@@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined.";
auto &alternate_evaluator = multi_poss_[u_eval.first];
auto &eval_cache = alternate_evaluator->cache();
if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) &&
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
*continue_flag = true;
return latest_entry;
}
......@@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval);
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
out_specs.push_back(eval_result->abstract());
eval_trace_.pop_back();
if (eval_trace_.empty()) {
multi_poss_.clear();
}
} else if (it != eval_trace_.rbegin()) {
} else {
bool continue_flag = false;
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
if (continue_flag) {
MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString();
continued_evals_.insert(current_inf);
continue;
}
// Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
<< " return out_spec: " << eval_result->abstract()->ToString();
return eval_result;
}
......
......@@ -26,6 +26,7 @@
#include <vector>
#include <utility>
#include <map>
#include <set>
#ifdef DEBUG
#include <stack>
......@@ -113,7 +114,8 @@ class AnfNodeConfig : public Config {
std::string ToString() const override {
std::ostringstream buffer;
buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
<< "), Context: " << context_->ToString();
return buffer.str();
}
......@@ -173,7 +175,13 @@ struct AnalysisResult {
};
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
struct PartialAppHasher {
std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
auto h1 = std::hash<AbstractFunctionPtr>{}(p.first);
auto h2 = AbstractBasePtrListHash(p.second);
return h1 ^ h2;
}
};
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
......@@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
constructors_app_;
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> continued_evals_;
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);
......
......@@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRef;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
......@@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
// only send string in external
if (!IsValueNode<StringImm>(node)) {
// Validate a type.
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
<< " for node=" << node->DebugString();
}
}
return;
......@@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>()) {
return;
}
......
......@@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
}
// Isomorphism
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node) {
static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node);
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node) {
if (equiv_node == nullptr) {
MS_LOG(ERROR) << "Invalid equiv_node";
return false;
......@@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu
MS_LOG(DEBUG) << "two parameters are not equal.";
return false;
}
if (node1->isa<CNode>() && node2->isa<CNode>()) {
return SameNode(node1, node2, equiv_func_graph, equiv_node);
}
MS_LOG(ERROR) << "type error";
return false;
}
......
......@@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
auto fg = std::make_shared<FuncGraph>();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}
TraceManager::DebugTrace(
std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
auto fg = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) {
if (!n->isa<CNode>()) {
......@@ -154,7 +157,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
}
TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
eqv[n] = fg->NewCNode(args);
TraceManager::EndTrace();
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}
......
......@@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
auto other_tensor = dyn_cast<AbstractTensor>(other);
if (other_tensor == nullptr) {
auto ref_tensor = dyn_cast<AbstractRef>(other);
if (ref_tensor != nullptr) {
return this->Join(ref_tensor->ref());
}
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
......
......@@ -48,7 +48,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
continue;
}
if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString();
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
}
rank[node] = todo.size();
bool cont = false;
......
......@@ -30,6 +30,7 @@
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
#include "utils/hashing.h"
using std::fabs;
......@@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr<Scalar>;
class BoolImm : public Scalar {
public:
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash<bool>{}(v_); }
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm, Scalar)
std::size_t hash() const override { return hash_; }
......@@ -91,7 +92,7 @@ class IntergerImm : public Scalar {
class Int8Imm : public IntergerImm {
public:
Int8Imm() : IntergerImm(kInt8), v_(0) {}
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t)
class Int16Imm : public IntergerImm {
public:
Int16Imm() : IntergerImm(kInt16), v_(0) {}
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int16Imm() override = default;
MS_DECLARE_PARENT(Int16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t)
class Int32Imm : public IntergerImm {
public:
Int32Imm() : IntergerImm(kInt32), v_(0) {}
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int32Imm() override = default;
MS_DECLARE_PARENT(Int32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t)
class Int64Imm : public IntergerImm {
public:
Int64Imm() : IntergerImm(kInt64), v_(0) {}
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash<int64_t>{}(v_); }
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
~Int64Imm() override = default;
MS_DECLARE_PARENT(Int64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t)
class UInt8Imm : public IntergerImm {
public:
UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt8Imm() override = default;
MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
class UInt16Imm : public IntergerImm {
public:
UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
class UInt32Imm : public IntergerImm {
public:
UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
class UInt64Imm : public IntergerImm {
public:
UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); }
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
}
~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
......@@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
class FP32Imm : public FloatImm {
public:
FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); }
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm, FloatImm)
std::size_t hash() const override { return hash_; }
......@@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
class FP64Imm : public FloatImm {
public:
FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); }
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
~FP64Imm() override = default;
MS_DECLARE_PARENT(FP64Imm, FloatImm)
std::size_t hash() const override { return hash_; }
......
......@@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
}
};
class TraceSegmentTransform : public TraceInfo {
public:
explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceSegmentTransform() override = default;
TraceInfoPtr clone() override {
return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
}
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_
此差异已折叠。
......@@ -58,6 +58,7 @@ add_subdirectory(serving)
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/base/*.cc"
"../../../mindspore/core/gvar/*.cc"
"../../../mindspore/core/abstract/*.cc"
"../../../mindspore/core/ir/*.cc"
"../../../mindspore/core/utils/*.cc"
......
......@@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()
......
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
CURRPATH=$(cd $(dirname $0); pwd)
IGNORE_EXEC="--ignore=$CURRPATH/exec"
PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd)
......
......@@ -14,7 +14,6 @@
# ============================================================================
"""Generate vm_impl function for array ops"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
......@@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
# pylint: disable=unused-argument
......@@ -181,8 +179,7 @@ def vm_impl_tile(self):
def vm_impl(x, multiples):
x = x.asnumpy()
multiples = multiples.asnumpy()
out = vm.Tile(x, multiples)
out = np.tile(x, multiples)
return Tensor(out)
return vm_impl
......@@ -255,7 +252,10 @@ def vm_impl_sum(self):
def vm_impl(x, axis):
x = x.asnumpy()
out = vm.sum(x, axis)
if axis == ():
out = np.sum(x)
else:
out = np.sum(x, axis=axis)
return Tensor(np.array(out))
return vm_impl
......@@ -291,12 +291,14 @@ def vm_impl_square(self):
return vm_impl
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
"""Generate vm_impl function for ZerosLike"""
def vm_impl(x):
return Tensor(np.zeros_like(x.asnumpy()))
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
"""Generate vm_impl function for Partial"""
......@@ -307,6 +309,7 @@ def vm_impl_partial(self):
return vm_impl
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
"""Generate vm_impl function for Depend"""
......
......@@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self):
return vm_impl
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
"""Generate vm_impl function for ReduceMean."""
def vm_impl(x, axis):
x = x.asnumpy()
if axis == ():
axis = None
out = np.amax(x, axis)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册