未验证 提交 5a2e5179 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add FusedMultiTransformer fuse pass for GPT3 (#45907)


* add fused_multi_transformer_encoder/decoder pass, run GPT-3 success
上级 4dc4d5fc
...@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base) ...@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference) pass_library(multihead_matmul_fuse_pass inference)
pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference) pass_library(yolo_box_fuse_pass inference)
...@@ -311,6 +313,14 @@ cc_test( ...@@ -311,6 +313,14 @@ cc_test(
test_multihead_matmul_fuse_pass test_multihead_matmul_fuse_pass
SRCS multihead_matmul_fuse_pass_tester.cc SRCS multihead_matmul_fuse_pass_tester.cc
DEPS multihead_matmul_fuse_pass) DEPS multihead_matmul_fuse_pass)
cc_test(
test_fused_multi_transformer_encoder_pass
SRCS fused_multi_transformer_encoder_pass_tester.cc
DEPS fused_multi_transformer_encoder_pass)
cc_test(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass)
cc_test( cc_test(
test_conv_bn_fuse_pass_cc test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc SRCS conv_bn_fuse_pass_tester.cc
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerDecoderPattern : public PatternBase {
FusedMultiTransformerDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_decoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(concat_0_in);
PATTERN_DECL_NODE(concat_0);
PATTERN_DECL_NODE(concat_0_out);
PATTERN_DECL_NODE(assign_0);
PATTERN_DECL_NODE(concat_1_in);
PATTERN_DECL_NODE(concat_1);
PATTERN_DECL_NODE(concat_1_out);
PATTERN_DECL_NODE(assign_1);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerDecoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerDecoderPass : public FusePassBase {
public:
FusedMultiTransformerDecoderPass();
virtual ~FusedMultiTransformerDecoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerDecoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerDecoderFuseQKVPass();
virtual ~FusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerEncoderPattern : public PatternBase {
FusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerEncoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerEncoderPass : public FusePassBase {
public:
FusedMultiTransformerEncoderPass();
virtual ~FusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerEncoderFuseQKVPass();
virtual ~FusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph, ...@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph,
// avoid kRootBlockIndex not 0 // avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue; if (idx == kRootBlockIndex) continue;
if (static_cast<int>(idx) < program_pb.blocks_size()) {
block = program_pb.mutable_blocks(idx);
} else {
block = program_pb.add_blocks(); block = program_pb.add_blocks();
block->set_idx(idx); block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex); block->set_parent_idx(kRootBlockIndex);
}
GraphToBlock(*graph.GetSubGraph(idx), GraphToBlock(*graph.GetSubGraph(idx),
block, block,
sort_kind, sort_kind,
......
...@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) { ...@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
for (auto &node : GraphTraits::DFS(graph)) { for (auto &node : GraphTraits::DFS(graph)) {
if (node.Name().rfind("__control_var") == 0) continue;
for (const auto &pdnode : pattern_.nodes()) { for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) { if (pdnode->Tell(&node)) {
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name(); VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
...@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const { ...@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const {
// Create Edges // Create Edges
for (const auto &edge : edges()) { for (const auto &edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue; continue;
} }
auto &src = node2dot.at(edge.first); auto &src = node2dot.at(edge.first);
...@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() { ...@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() {
PDNode *PDNode::assert_is_persistable_var() { PDNode *PDNode::assert_is_persistable_var() {
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); }); asserts_.emplace_back(
[=](Node *x) { return x->Var() && x->Var()->Persistable(); });
return this; return this;
} }
......
...@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase { ...@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase {
a->outputs.push_back(b); \ a->outputs.push_back(b); \
b->inputs.push_back(a); b->inputs.push_back(a);
// UnLink 2 ir::Nodes from each other.
#define IR_NODE_UNLINK(a, b) \
a->outputs.erase( \
std::remove(std::begin(a->outputs), std::end(a->outputs), b), \
std::end(a->outputs)); \
b->inputs.erase(std::remove(std::begin(b->inputs), std::end(b->inputs), a), \
std::end(b->inputs));
// Set the out_var as the output of the op // Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \ #define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \ op->outputs.push_back(out_var); \
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Scope;
namespace ir { namespace ir {
class Graph; class Graph;
} // namespace ir } // namespace ir
...@@ -35,6 +36,17 @@ namespace paddle { ...@@ -35,6 +36,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kParamScopeAttr[] = "__param_scope__";
static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
};
Graph *Pass::Apply(Graph *graph) const { Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph"; VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass(); CheckPrevPass();
...@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const { ...@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const {
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency.")); "The VarDescs of persistable variable are not consistency."));
applied_ = true;
if (!graph->Has(kPassRecorder)) { if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder); graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
} }
graph->Get<PassRecorder>(kPassRecorder).insert(Type()); graph->Get<PassRecorder>(kPassRecorder).insert(Type());
if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
support_subgraph_passes.end(),
Type())) {
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
sub_graph->SetNotOwned<Scope>(
framework::ir::kParamScopeAttr,
&graph->Get<Scope>(framework::ir::kParamScopeAttr));
}
ApplyImpl(sub_graph);
PADDLE_ENFORCE_EQ(
HasCircle(*sub_graph),
false,
platform::errors::InvalidArgument(
"Illegal pass %s. Generated graph shouldn't contain cycle.",
Type()));
PADDLE_ENFORCE_EQ(
VarDescIsConsistency(*sub_graph),
true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
if (!sub_graph->Has(kPassRecorder)) {
sub_graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
sub_graph->Get<PassRecorder>(kPassRecorder).insert(Type());
}
}
applied_ = true;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded // Passes can change params, tensors, so caching need to be discarded
......
...@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder"; ...@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr char kEmbEltwiseLayernormPass[] = constexpr char kEmbEltwiseLayernormPass[] =
"embedding_eltwise_layernorm_fuse_pass_flag"; "embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
constexpr char kFusedMultiTransformerEncoderPass[] =
"fused_multi_transformer_encoder_pass_flag";
constexpr char kFusedMultiTransformerDecoderPass[] =
"fused_multi_transformer_decoder_pass_flag";
constexpr char kFusedMultiTransformerEncoderFuseQKVPass[] =
"fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kFusedMultiTransformerDecoderFuseQKVPass[] =
"fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kPrelnEmbEltwiseLayernormPass[] = constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag"; "preln_embedding_eltwise_layernorm_fuse_pass_flag";
......
...@@ -146,6 +146,12 @@ struct Layers { ...@@ -146,6 +146,12 @@ struct Layers {
return unary_op("relu", x, out); return unary_op("relu", x, out);
} }
VarDesc* gelu(VarDesc* x, VarDesc* out = nullptr, bool approximate = true) {
AttributeMap attrs;
attrs["approximate"] = approximate;
return unary_op("gelu", x, out, &attrs);
}
VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) { VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("sigmoid", x, out); return unary_op("sigmoid", x, out);
} }
...@@ -154,6 +160,20 @@ struct Layers { ...@@ -154,6 +160,20 @@ struct Layers {
return unary_op("tanh", x, out); return unary_op("tanh", x, out);
} }
VarDesc* c_identity(VarDesc* x, VarDesc* out = nullptr, int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_identity", x, out, &attrs);
}
VarDesc* c_allreduce_sum(VarDesc* x,
VarDesc* out = nullptr,
int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_allreduce_sum", x, out, &attrs);
}
VarDesc* fc(VarDesc* input, VarDesc* fc(VarDesc* input,
VarDesc* w, VarDesc* w,
VarDesc* bias, VarDesc* bias,
...@@ -332,6 +352,37 @@ struct Layers { ...@@ -332,6 +352,37 @@ struct Layers {
return outs; return outs;
} }
std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) {
std::vector<VarDesc*> outs(num_or_section);
for (int i = 0; i < num_or_section; i++) {
outs[i] = lod_tensor(unique_name());
}
std::vector<std::string> out_names(num_or_section);
for (int i = 0; i < num_or_section; i++) {
out_names[i] = outs[i]->Name();
}
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("split");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", out_names);
op->SetAttr("num_or_section", num_or_section);
op->SetAttr("axis", axis);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return outs;
}
VarDesc* assign(VarDesc* x) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* matmul(VarDesc* x, VarDesc* matmul(VarDesc* x,
VarDesc* y, VarDesc* y,
VarDesc* alpha = nullptr, VarDesc* alpha = nullptr,
...@@ -459,6 +510,24 @@ struct Layers { ...@@ -459,6 +510,24 @@ struct Layers {
return out; return out;
} }
VarDesc* while_loop(std::vector<VarDesc*> xs, VarDesc* cond = nullptr) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* step_scopes = lod_tensor(unique_name());
if (cond == nullptr) cond = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("while");
std::vector<std::string> xs_names;
for (auto& x : xs) xs_names.emplace_back(x->Name());
op->SetInput("X", xs_names);
op->SetInput("Condition", {cond->Name()});
op->SetOutput("Out", {out->Name()});
op->SetOutput("StepScopes", {step_scopes->Name()});
op->SetAttr("sub_block", {program_.MutableBlock(0)});
op->SetAttr("is_test", true);
return out;
}
void backward(std::vector<VarDesc*> targets) { void backward(std::vector<VarDesc*> targets) {
// This function is designed to simulate the structure of training program, // This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program. // but is constructed differently as the actual program.
...@@ -523,7 +592,10 @@ struct Layers { ...@@ -523,7 +592,10 @@ struct Layers {
return var; return var;
} }
VarDesc* unary_op(std::string type, VarDesc* x, VarDesc* out = nullptr) { VarDesc* unary_op(std::string type,
VarDesc* x,
VarDesc* out = nullptr,
const AttributeMap* attrs = nullptr) {
if (!out) { if (!out) {
out = lod_tensor(unique_name()); out = lod_tensor(unique_name());
} }
...@@ -531,6 +603,11 @@ struct Layers { ...@@ -531,6 +603,11 @@ struct Layers {
op->SetType(type); op->SetType(type);
op->SetInput("X", {x->Name()}); op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()}); op->SetOutput("Out", {out->Name()});
if (attrs) {
for (auto& iter : *attrs) {
op->SetAttr(iter.first, iter.second);
}
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward)); static_cast<int>(OpRole::kForward));
return out; return out;
......
...@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle( ...@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle(
} else { } else {
// Normal operators. // Normal operators.
for (const Node* node : requires) { for (const Node* node : requires) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue; if (node->Var()->Persistable()) continue;
std::string var = node->Name(); std::string var = node->Name();
if (!lifecycles->count(var)) { if (!lifecycles->count(var)) {
...@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize( ...@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// between performance and underlying principle. // between performance and underlying principle.
std::unordered_set<std::string> black_list; std::unordered_set<std::string> black_list;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsVar() && if (node->IsVar() && node->Var() &&
node->Var()->GetType() == node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
if (!valid_var(node)) { if (!valid_var(node)) {
...@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize( ...@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// Collect tensors from graph. // Collect tensors from graph.
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsVar() && if (node->IsVar() && node->Var() &&
node->Var()->GetType() == node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
!black_list.count(node->Var()->Name())) { !black_list.count(node->Var()->Name())) {
......
...@@ -199,6 +199,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -199,6 +199,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", // "multihead_matmul_fuse_pass_v2", //
"fused_multi_transformer_encoder_pass", //
"fused_multi_transformer_decoder_pass", //
"fused_multi_transformer_encoder_fuse_qkv_pass", //
"fused_multi_transformer_decoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", // "gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", // "gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", // "gpu_cpu_flatten2_matmul_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册