未验证 提交 c29dc34e 编写于 作者: J Jiabin Yang 提交者: GitHub

optimize reshape related fusion (#53066)

上级 e669528a
...@@ -510,8 +510,16 @@ void AnalyseClusterVariables( ...@@ -510,8 +510,16 @@ void AnalyseClusterVariables(
bool is_inference_stage, bool is_inference_stage,
const std::unordered_set<std::string>& skip_gc_var_names) { const std::unordered_set<std::string>& skip_gc_var_names) {
// collecting all input and output of op // collecting all input and output of op
std::unordered_set<std::string> unused_outputs;
std::unordered_set<std::string> legacy_ops{"reshape2", "transpose2"};
for (auto* op_node : cluster) { for (auto* op_node : cluster) {
const auto& op_name = op_node->Name(); const auto& op_name = op_node->Name();
if (legacy_ops.count(op_name) && op_node->Op()->HasOutput("XShape")) {
for (const auto& var_name :
(*(op_node->Op()->MutableOutputs()))["XShape"]) {
unused_outputs.insert(var_name);
}
}
for (auto* input_var_node : op_node->inputs) { for (auto* input_var_node : op_node->inputs) {
if (!deny_var_set.count(input_var_node->Name())) { if (!deny_var_set.count(input_var_node->Name())) {
// ignore deny var node // ignore deny var node
...@@ -527,9 +535,12 @@ void AnalyseClusterVariables( ...@@ -527,9 +535,12 @@ void AnalyseClusterVariables(
// remove output node from cluster_inputs, // remove output node from cluster_inputs,
// and add cluster_internals node // and add cluster_internals node
for (auto* var_node : *cluster_outputs) { for (auto* var_node : *cluster_outputs) {
if (cluster_inputs->count(var_node) > 0) { if ((cluster_inputs->count(var_node) > 0) ||
(unused_outputs.count(var_node->Name()))) {
// if a input node also exists in output list, remove // if a input node also exists in output list, remove
cluster_inputs->erase(var_node); if (cluster_inputs->count(var_node) > 0) {
cluster_inputs->erase(var_node);
}
// the internal node is must an output node of sub-graph, // the internal node is must an output node of sub-graph,
// but not any input node of out-graph. // but not any input node of out-graph.
...@@ -538,8 +549,12 @@ void AnalyseClusterVariables( ...@@ -538,8 +549,12 @@ void AnalyseClusterVariables(
for (size_t i = 0; i < var_node->outputs.size() && is_only_used_internal; for (size_t i = 0; i < var_node->outputs.size() && is_only_used_internal;
++i) { ++i) {
is_only_used_internal &= (cluster.count(var_node->outputs[i]) > 0); is_only_used_internal &= (cluster.count(var_node->outputs[i]) > 0);
VLOG(3) << "var_node->outputs[" << i << "]: " << var_node->Name()
<< ", is_only_used_internal: " << is_only_used_internal;
} }
if (is_only_used_internal) { if (is_only_used_internal) {
VLOG(3) << "insert internal var: " << var_node->Name();
cluster_internals->insert(var_node); cluster_internals->insert(var_node);
} }
} }
......
...@@ -164,10 +164,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): ...@@ -164,10 +164,12 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
out = difference * rsqrt_var out = difference * rsqrt_var
if scale is not None: if scale is not None:
scale = reshape(scale, x.shape[begin_norm_axis:]) if x.shape[begin_norm_axis:] is not scale.shape:
scale = reshape(scale, x.shape[begin_norm_axis:])
out = out * scale out = out * scale
if bias is not None: if bias is not None:
bias = reshape(bias, x.shape[begin_norm_axis:]) if x.shape[begin_norm_axis:] is not bias.shape:
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + bias out = out + bias
mean_ = reshape(mean_, [-1]) mean_ = reshape(mean_, [-1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册