Commit 22a9d02e authored by panyifeng

switch_layer: incorporate env_getitem and tuple_getitem

Parent 4f46c427
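
This change adds IR passes that incorporate an env_getitem or tuple_getitem applied to the result of a switch_layer call into each candidate branch graph:

- incorporate_env_getitem_switch_layer (IncorporateEnvGetitemSwitchLayer, backed by the new EnvGetitemTransformACrossGraph), registered in the b_1 pass group;
- IncorporateGetitemSwitchLayerA and IncorporateGetitemSwitchLayerB (the latter backed by the new GetItemTransformACrossGraph), added to IncorporateGetitemSet.

A new test, test_switch_layer_env_eliminate, covers the gradient-by-parameter-list case that produces these patterns.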
......@@ -95,6 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_layer_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer",
prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
......
......@@ -58,6 +58,7 @@ class OptimizeIRPassLib {
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;
SubstitutionPtr incorporate_env_getitem_switch_layer_;
// Ref eliminate
SubstitutionPtr make_ref_eliminate_;
......
......@@ -91,6 +91,69 @@ class EnvGetitemTransform {
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
cache_;
};
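// Like EnvGetitemTransform, but the given fg is an outer graph whose output is itself a
// func_graph (the shape produced by switch_layer branches). Both the outer graph and the
// inner graph it returns are cloned; the inner clone's env output is replaced by the value
// bound to `key` (or by an env_getitem with the given default), and the cloned outer graph
// is cached per (key, default_node) pair.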
class EnvGetitemTransformACrossGraph {
public:
EnvGetitemTransformACrossGraph() : cache_() {}
~EnvGetitemTransformACrossGraph() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
if (cache_.find(fg) == cache_.end()) {
cache_[fg] = {};
}
auto &cache = cache_[fg];
auto hash_key = std::make_pair(key, default_node);
if (cache.find(hash_key) == cache.end()) {
std::ostringstream ss("env", std::ostringstream::app);
if (key->node() != nullptr) {
ss << key->node()->ToString();
}
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
auto output_outer = new_fg_outer->output();
if (!IsValueNode<FuncGraph>(output_outer)) {
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
return nullptr;
}
auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer);
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
new_fg_outer->set_output(NewValueNode(new_fg));
auto env = new_fg->output();
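// Walk the chain of EnvSetItem nodes that builds the env output; if one of them sets
// `key`, make the inner clone return that value directly, otherwise fall back to an
// env_getitem with the default below.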
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
auto value = inputs[3];
auto key2 = GetValueNode<SymbolicKeyInstancePtr>(inputs[2]);
if (*key2 == *key) {
new_fg->set_output(value);
cache[hash_key] = new_fg_outer;
return new_fg_outer;
}
}
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node}));
cache[hash_key] = new_fg_outer;
}
return cache[hash_key];
}
private:
std::unordered_map<FuncGraphPtr,
std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
cache_;
};
} // namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
......@@ -358,6 +421,78 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
};
// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y}
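// The env_getitem is not applied to the switch_layer call directly: the switch_layer result
// is called with Xs, and that result is called again with Ys. Each candidate graph Gi is
// rewritten with EnvGetitemTransformACrossGraph so the env_getitem happens inside the branch,
// then the switch_layer and both calls are rebuilt over the transformed graphs.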
class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
public:
IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {}
~IncorporateEnvGetitemSwitchLayer() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, C, Y}
auto cnode = node->cast<CNodePtr>();
auto inp1 = cnode->input(1)->cast<CNodePtr>();
auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
auto default_v = cnode->input(3);
// {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}
auto &inputs_outer = inp1->inputs();
if (!inputs_outer[0]->isa<CNode>()) {
return nullptr;
}
std::vector<AnfNodePtr> args_outer;
args_outer.insert(args_outer.end(), inputs_outer.begin() + 1, inputs_outer.end());
auto &input_switch_layer = inputs_outer[0]->cast<CNodePtr>()->inputs();
is_match_ = false;
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(input_switch_layer[0]);
if (!is_match_) {
return nullptr;
}
std::vector<AnfNodePtr> args;
(void)args.insert(args.end(), input_switch_layer.begin() + 1, input_switch_layer.end());
// {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}
auto sw = input_switch_layer[0]->cast<CNodePtr>();
std::vector<FuncGraphPtr> graphs{};
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
auto &graphs_inputs = graphs_cnode->inputs();
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
}
if (graphs.empty()) {
return nullptr;
}
auto fg = node->func_graph();
std::vector<AnfNodePtr> layers;
for (auto &graph : graphs) {
auto fg_transform = env_get_item_transform_(graph, key, default_v);
if (fg_transform == nullptr) {
return nullptr;
}
layers.push_back(NewValueNode(fg_transform));
}
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers);
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitchLayer), sw->input(1), layers_node});
args.insert(args.begin(), new_sw);
auto inner_call = fg->NewCNode(args);
args_outer.insert(args_outer.begin(), inner_call);
return fg->NewCNode(args_outer);
}
void Visit(const AnfNodePtr &) override { is_match_ = true; }
private:
bool is_match_{false};
internal::EnvGetitemTransformACrossGraph env_get_item_transform_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -72,6 +72,52 @@ class GetitemTransform {
private:
std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_;
};
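// Like GetitemTransform, but the given fg is an outer graph whose output is itself a
// func_graph. The outer graph and the inner graph it returns are both cloned; the inner
// clone's output is replaced by element idx of its tuple output (taken directly from the
// make_tuple when possible, otherwise via an inserted tuple_getitem), and the cloned outer
// graph is cached per index.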
class GetItemTransformACrossGraph {
public:
GetItemTransformACrossGraph() : cache_() {}
~GetItemTransformACrossGraph() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) {
if (cache_.find(fg) == cache_.end()) {
cache_[fg] = {};
}
auto &cache = cache_[fg];
if (cache.find(idx) == cache.end()) {
std::ostringstream ss("tp", std::ostringstream::app);
ss << idx;
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
auto output_outer = new_fg_outer->output();
if (!IsValueNode<FuncGraph>(output_outer)) {
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
return nullptr;
}
auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer);
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str()));
new_fg_outer->set_output(NewValueNode(new_fg));
auto output = new_fg->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto cnode = output->cast<CNodePtr>();
auto ids = IntToSize(idx + 1);
// Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1.
if (ids >= cnode->size()) {
MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size();
}
new_fg->set_output(cnode->input(ids));
} else {
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)}));
}
cache[idx] = new_fg_outer;
}
return cache[idx];
}
private:
std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_;
};
} // namespace internal
// {prim::kPrimTupleGetItem, {G, Xs}, C}
......@@ -385,13 +431,199 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_;
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, C}
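// Single-call form: the switch_layer result is called once with Xs and the tuple_getitem is
// applied to that call. Each candidate graph Gi is rewritten with GetitemTransform so the
// element selection happens inside the branch, then the switch_layer call is rebuilt over
// the transformed graphs.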
class IncorporateGetitemSwitchLayerA : public AnfVisitor {
public:
IncorporateGetitemSwitchLayerA() : getitem_transform_() {}
~IncorporateGetitemSwitchLayerA() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
is_in_get_ = true;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
is_in_get_ = false;
auto fg = node->func_graph();
if (idx_ == -1 || switch_layer_ == nullptr || fg == nullptr) {
return nullptr;
}
is_in_switch_ = true;
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_);
is_in_switch_ = false;
if (graphs_.empty()) {
return nullptr;
}
std::vector<AnfNodePtr> layers;
for (auto &graph : graphs_) {
auto fg_transform = getitem_transform_(graph, idx_);
if (fg_transform == nullptr) {
return nullptr;
}
layers.push_back(NewValueNode(fg_transform));
}
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers);
std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node};
auto sw_node = fg->NewCNode(sw_args);
(void)args_.insert(args_.begin(), sw_node);
return fg->NewCNode(args_);
}
void Visit(const AnfNodePtr &node) override {
if (is_in_switch_ && x_ == nullptr) {
x_ = node;
return;
}
AnfVisitor::Visit(node);
}
void Visit(const CNodePtr &cnode) override {
if (is_in_get_ && cnode->size() != 0) {
auto &inputs = cnode->inputs();
switch_layer_ = inputs[0];
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
}
if (is_in_switch_ && cnode->size() > 2) {
auto &inputs = cnode->inputs();
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
}
}
}
void Visit(const ValueNodePtr &vnode) override {
if (is_in_get_) {
idx_ = GetValue<int>(vnode->value());
}
}
void Reset() {
x_ = nullptr;
graphs_.clear();
switch_layer_ = nullptr;
args_.clear();
is_in_get_ = false;
is_in_switch_ = false;
}
private:
int idx_{-1};
AnfNodePtr switch_layer_{nullptr}, x_{nullptr};
std::vector<FuncGraphPtr> graphs_{};
bool is_in_get_{false}, is_in_switch_{false};
std::vector<AnfNodePtr> args_{};
internal::GetitemTransform getitem_transform_;
};
// {prim::kPrimTupleGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C}
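// Double-call form: the switch_layer result is called with Xs, the func_graph it returns is
// called again with Ys, and the tuple_getitem is applied to the outer call. Each Gi is
// rewritten with GetItemTransformACrossGraph and both calls are rebuilt around the new
// switch_layer.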
class IncorporateGetitemSwitchLayerB : public AnfVisitor {
public:
IncorporateGetitemSwitchLayerB() : getitem_transform_() {}
~IncorporateGetitemSwitchLayerB() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
is_in_get_ = true;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
is_in_get_ = false;
auto fg = node->func_graph();
if (idx_ == -1 || switch_layer_call_ == nullptr || !switch_layer_call_->isa<CNode>() || fg == nullptr) {
return nullptr;
}
auto &switch_layer_call_inputs = switch_layer_call_->cast<CNodePtr>()->inputs();
(void)std::copy(switch_layer_call_inputs.begin() + 1, switch_layer_call_inputs.end(), std::back_inserter(args_));
is_in_switch_ = true;
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_call_inputs[0]);
is_in_switch_ = false;
if (graphs_.empty()) {
return nullptr;
}
std::vector<AnfNodePtr> layers;
for (auto &graph : graphs_) {
auto fg_transform = getitem_transform_(graph, idx_);
if (fg_transform == nullptr) {
return nullptr;
}
layers.push_back(NewValueNode(fg_transform));
}
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers);
std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node};
auto sw_node = fg->NewCNode(sw_args);
(void)args_.insert(args_.begin(), sw_node);
auto call_switch_layer = fg->NewCNode(args_);
(void)outer_call_args_.insert(outer_call_args_.begin(), call_switch_layer);
return fg->NewCNode(outer_call_args_);
}
void Visit(const AnfNodePtr &node) override {
if (is_in_switch_ && x_ == nullptr) {
x_ = node;
return;
}
AnfVisitor::Visit(node);
}
void Visit(const CNodePtr &cnode) override {
if (is_in_get_ && cnode->size() != 0) {
auto &inputs = cnode->inputs();
switch_layer_call_ = inputs[0];
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
}
if (is_in_switch_ && cnode->size() > 2) {
auto &inputs = cnode->inputs();
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
}
}
}
void Visit(const ValueNodePtr &vnode) override {
if (is_in_get_) {
idx_ = GetValue<int>(vnode->value());
}
}
void Reset() {
x_ = nullptr;
graphs_.clear();
switch_layer_call_ = nullptr;
args_.clear();
outer_call_args_.clear();
is_in_get_ = false;
is_in_switch_ = false;
}
private:
int idx_{-1};
AnfNodePtr switch_layer_call_{nullptr}, x_{nullptr};
std::vector<FuncGraphPtr> graphs_{};
bool is_in_get_{false}, is_in_switch_{false};
std::vector<AnfNodePtr> args_{};
std::vector<AnfNodePtr> outer_call_args_{};
internal::GetItemTransformACrossGraph getitem_transform_;
};
class IncorporateGetitemSet : public OptimizerCaller {
public:
IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()),
incorporate_getitem_switch_layer_a_(std::make_shared<IncorporateGetitemSwitchLayerA>()),
incorporate_getitem_switch_layer_b_(std::make_shared<IncorporateGetitemSwitchLayerB>()) {
eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_);
eliminaters_.emplace_back(incorporate_getitem_switch_layer_a_);
eliminaters_.emplace_back(incorporate_getitem_switch_layer_b_);
}
~IncorporateGetitemSet() = default;
......@@ -407,7 +639,8 @@ class IncorporateGetitemSet : public OptimizerCaller {
}
private:
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_, incorporate_getitem_switch_layer_a_,
incorporate_getitem_switch_layer_b_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
......
......@@ -180,7 +180,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.value_based_eliminate_});
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,
......
......@@ -464,6 +464,36 @@ def test_switch_layer_with_single_prim():
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_switch_layer_env_eliminate():
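# Indexing a tuple of Cells with a Tensor index lowers to switch_layer; taking the gradient
# by parameter list then produces env_getitem on the switch_layer result, which the new
# incorporate passes fold into each branch graph.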
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same')
self.funs = (self.conv, self.conv2)
def construct(self, x, index):
x = self.funs[index](x)
return x
class NetGrad(nn.Cell):
def __init__(self, net):
super(NetGrad, self).__init__()
self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
self.net = net
self.weights = ParameterTuple(self.net.trainable_params())
def construct(self, x, index):
weights = self.weights
grad = self.grad_op(self.net, weights)(x, index)
return grad
net = Net()
net2 = NetGrad(net)
x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
i = Tensor(1, ms.int32)
net2(x, i)
def test_control_depend_check():
with pytest.raises(TypeError) as e:
P.ControlDepend(0.0)
......