Commit cd6e8d65 authored by: H huanghui

fix ReluV2's mask shape in derelu fusion pass

Parent b48d663c
@@ -46,6 +46,8 @@
#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h"
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
@@ -94,8 +96,10 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
}
} // namespace
......
@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include "session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
@@ -89,6 +90,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
auto reduce_sum = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(reduce_sum);
auto mul1 = reduce_sum->input(1);
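// This fusion is gated to the BERT encoder subgraph: bail out for nodes created under any other scope.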
if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) {
return nullptr;
}
if (IsUsedByOthers(graph, mul1)) {
MS_LOG(INFO) << "Mul1 is used by others, quit fusion!";
return nullptr;
......
@@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(relu->scope());
// ReluV2's 2nd output is a mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector
// ReluV2's 2nd output is a mask whose data type is uint8
TypeId mask_dtype = kNumberTypeUInt8;
std::vector<size_t> mask_shape;
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
if (mask_shape.size() != 4) {
MS_LOG(WARNING) << "ReLU's infer shape size is not equal to 4";
return nullptr;
}
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) {
mask_shape[1] = (mask_shape[1] + 31) / 32;
mask_shape.push_back(4);
} else {
mask_shape[1] = (mask_shape[1] + 15) / 16;
mask_shape.push_back(2);
}
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype};
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
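
The new shape logic mirrors ReluV2's packed mask layout on Ascend: the channel dimension is split into groups (32 elements per group for 1-byte dtypes, 16 otherwise) and a trailing byte dimension is appended. Below is a minimal standalone sketch of the same rule; the helper name and the sample shape are illustrative, not from the commit.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the mask-shape rule above (assumes a 4-D NCHW shape).
std::vector<size_t> ReluV2MaskShape(std::vector<size_t> shape, bool one_byte_dtype) {
  const size_t group = one_byte_dtype ? 32 : 16;  // elements packed per channel group
  const size_t tail = one_byte_dtype ? 4 : 2;     // trailing byte dimension
  shape[1] = (shape[1] + group - 1) / group;      // ceil(C / group)
  shape.push_back(tail);
  return shape;
}

int main() {
  // A float16 NCHW tensor {32, 64, 112, 112} yields mask shape {32, 4, 112, 112, 2}.
  for (size_t d : ReluV2MaskShape({32, 64, 112, 112}, false)) std::cout << d << ' ';
  std::cout << '\n';
}
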
@@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
MS_EXCEPTION_IF_NULL(relu);
auto relu_v2 = CreateReluV2(graph, relu);
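// CreateReluV2 returns nullptr when the ReLU output is not 4-D; skip the fusion in that case.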
if (relu_v2 == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> relu_v2_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
......
@@ -120,7 +120,7 @@ constexpr auto kStreamActiveOpName = "StreamActive";
constexpr auto kAssignAddOpName = "AssignAdd";
constexpr auto kSendOpName = "Send";
constexpr auto kRecvOpName = "Recv";
constexpr auto kReluV2OpName = "ReluV2";
constexpr auto kReluV2OpName = "ReLUV2";
constexpr auto kReluGradV2OpName = "ReluGradV2";
// attr key name
......
@@ -32,6 +32,11 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before");
EXPECT_NE(g, nullptr);
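// Give every node a "bert/encoder" scope so that ConfusionMulGradFusion's scope gate matches.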
auto bert_scope = std::make_shared<Scope>("bert/encoder");
for (auto node : TopoSort(g->get_return())) {
node->set_scope(bert_scope);
}
std::vector<int> shp{1, 1, 1, 1};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......
@@ -17,7 +17,7 @@ from mindspore.ops import Primitive
relu = P.ReLU()
relu_grad = Primitive('ReluGrad')
relu_v2 = Primitive('ReluV2')
relu_v2 = Primitive('ReLUV2')
relu_grad_v2 = Primitive('ReluGradV2')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
......