未验证 提交 259b0aad 编写于 作者: W wz1qqx 提交者: GitHub

[XPU] fix error pattern and rename max name (#52726)

上级 327c0e4d
...@@ -99,13 +99,15 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -99,13 +99,15 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
auto conv = pattern->NewNode(conv_repr())->assert_is_op(conv_type_); auto conv = pattern->NewNode(conv_repr())->assert_is_op(conv_type_);
auto input = pattern->NewNode(input_repr()) auto input = pattern->NewNode(input_repr())
->assert_is_op_input(conv_type_, "Input") ->assert_is_op_input(conv_type_, "Input")
->AsInput(); ->AsInput()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 4;
});
auto conv_filter = pattern->NewNode(conv_filter_repr()) auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input(conv_type_, "Filter") ->assert_is_op_input(conv_type_, "Filter")
->AsInput(); ->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr()) auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output(conv_type_, "Output") ->assert_is_op_output(conv_type_, "Output");
->assert_var_not_persistable();
conv->LinksFrom({input, conv_filter}).LinksTo({conv_out}); conv->LinksFrom({input, conv_filter}).LinksTo({conv_out});
// ew_bias_add op // ew_bias_add op
PDNode* ew_bias_add = nullptr; PDNode* ew_bias_add = nullptr;
...@@ -116,11 +118,17 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -116,11 +118,17 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
ew_bias_add_y = pattern->NewNode(ew_bias_add_y_repr()) ew_bias_add_y = pattern->NewNode(ew_bias_add_y_repr())
->assert_is_op_input("elementwise_add", "Y") ->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_has_n_outputs(1); ->assert_has_n_outputs(1)
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
});
ew_bias_add = ew_bias_add =
pattern->NewNode(ew_bias_add_repr())->assert_is_op("elementwise_add"); pattern->NewNode(ew_bias_add_repr())->assert_is_op("elementwise_add");
ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr()) ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr())
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output("elementwise_add", "Out");
if (with_bn_ || with_branch_ || !act_type_.empty()) {
ew_bias_add_out->assert_has_n_outputs(1);
}
ew_bias_add->LinksFrom({conv_out, ew_bias_add_y}) ew_bias_add->LinksFrom({conv_out, ew_bias_add_y})
.LinksTo({ew_bias_add_out}); .LinksTo({ew_bias_add_out});
} else { } else {
...@@ -159,6 +167,9 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -159,6 +167,9 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
bn_out = bn_out =
pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y"); pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y");
if (with_branch_ || !act_type_.empty()) {
bn_out->assert_has_n_outputs(1);
}
bn_mean_out = pattern->NewNode(bn_mean_out_repr()) bn_mean_out = pattern->NewNode(bn_mean_out_repr())
->assert_is_op_output("batch_norm", "MeanOut"); ->assert_is_op_output("batch_norm", "MeanOut");
bn_saved_mean = pattern->NewNode(bn_saved_mean_repr()) bn_saved_mean = pattern->NewNode(bn_saved_mean_repr())
...@@ -179,23 +190,27 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -179,23 +190,27 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
bn_out->assert_is_op_input("elementwise_add", "Y")->AsIntermediate(); bn_out->assert_is_op_input("elementwise_add", "Y")->AsIntermediate();
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr()) ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "X") ->assert_is_op_input("elementwise_add", "X")
->AsInput() ->AsInput();
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 4;
});
} else if (with_branch_y_) { } else if (with_branch_y_) {
bn_out->assert_is_op_input("elementwise_add", "X")->AsIntermediate(); bn_out->assert_is_op_input("elementwise_add", "X")->AsIntermediate();
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr()) ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "Y") ->assert_is_op_input("elementwise_add", "Y")
->AsInput() ->AsInput();
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 4;
});
} }
ew_branch_add = ew_branch_add = pattern->NewNode(ew_branch_add_repr())
pattern->NewNode(ew_branch_add_repr())->assert_is_op("elementwise_add"); ->assert_is_op("elementwise_add")
->assert_more([](Node* node) {
if (node->inputs.size() != 2) {
return false;
}
return node->inputs[0]->Var()->GetShape() ==
node->inputs[1]->Var()->GetShape();
});
ew_branch_add_out = pattern->NewNode(ew_branch_add_out_repr()) ew_branch_add_out = pattern->NewNode(ew_branch_add_out_repr())
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output("elementwise_add", "Out");
if (!act_type_.empty()) {
ew_branch_add_out->assert_has_n_outputs(1);
}
ew_branch_add->LinksFrom({bn_out, ew_branch_add_in}) ew_branch_add->LinksFrom({bn_out, ew_branch_add_in})
.LinksTo({ew_branch_add_out}); .LinksTo({ew_branch_add_out});
} else { } else {
...@@ -401,6 +416,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -401,6 +416,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
auto filter_dims = filter_t->dims(); auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias; bool has_bias = with_bn || with_conv_bias;
bool has_branch = with_branch_x || with_branch_y;
// Create conv_fusion_bias (conv bias) variable // Create conv_fusion_bias (conv bias) variable
Node* fusion_bias_node = nullptr; Node* fusion_bias_node = nullptr;
if (has_bias) { if (has_bias) {
...@@ -501,18 +517,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -501,18 +517,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
framework::OpDesc conv2d_xpu_op_desc(block); framework::OpDesc conv2d_xpu_op_desc(block);
// set input&output var // set input&output var
conv2d_xpu_op_desc.SetType("conv2d_xpu"); conv2d_xpu_op_desc.SetType("conv2d_xpu");
conv2d_xpu_op_desc.SetInput("input", {input->Name()}); conv2d_xpu_op_desc.SetInput("x", {input->Name()});
conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()}); conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()});
conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()}); conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()});
conv2d_xpu_op_desc.SetOutput("output", {conv2d_xpu_out_name}); conv2d_xpu_op_desc.SetOutput("out", {conv2d_xpu_out_name});
conv2d_xpu_op_desc.SetOutput("output_max", {conv_out_max_name}); conv2d_xpu_op_desc.SetOutput("out_max", {conv_out_max_name});
// set fusion_bias input node // set fusion_bias input node
if (has_bias) { if (has_bias) {
conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()});
conv2d_xpu_op_desc.SetAttr("has_bias", has_bias);
} }
// set ew_branch_add input node // set ew_branch_add input node
if (ew_branch_add_in != nullptr) { if (ew_branch_add != nullptr) {
conv2d_xpu_op_desc.SetInput("branch", {ew_branch_add_in->Name()}); conv2d_xpu_op_desc.SetInput("branch", {ew_branch_add_in->Name()});
} }
// set attrs of conv2d_xpu // set attrs of conv2d_xpu
...@@ -566,7 +581,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -566,7 +581,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
conv2d_xpu_op_desc.SetAttr("place_z", std::vector<int>{10}); conv2d_xpu_op_desc.SetAttr("place_z", std::vector<int>{10});
conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings); conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings);
conv2d_xpu_op_desc.SetAttr("block_lod", std::vector<int>{1}); conv2d_xpu_op_desc.SetAttr("block_lod", std::vector<int>{1});
conv2d_xpu_op_desc.SetAttr("has_branch", with_branch_x || with_branch_y); conv2d_xpu_op_desc.SetAttr("has_branch", has_branch);
conv2d_xpu_op_desc.SetAttr("has_bias", has_bias);
auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc); auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc);
IR_NODE_LINK_TO(input, conv2d_xpu); IR_NODE_LINK_TO(input, conv2d_xpu);
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
# otherwise the operator only could be used in static mode. # otherwise the operator only could be used in static mode.
- op : conv2d_xpu - op : conv2d_xpu
args : (Tensor input, Tensor input_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(output), Tensor(output_max) output : Tensor(out), Tensor(out_max)
infer_meta : infer_meta :
func : Conv2dXPUInferMeta func : Conv2dXPUInferMeta
kernel : kernel :
func : conv2d_xpu func : conv2d_xpu
data_type : input data_type : x
optional : bias, branch, input_max optional : bias, branch, x_max
- op : embedding_with_eltwise_add_xpu - op : embedding_with_eltwise_add_xpu
args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
......
...@@ -35,8 +35,8 @@ inline int ConvOutSize(int input_size, ...@@ -35,8 +35,8 @@ inline int ConvOutSize(int input_size,
return output_size; return output_size;
} }
void Conv2dXPUInferMeta(const MetaTensor& input, void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& input_max, const MetaTensor& x_max,
const MetaTensor& filter, const MetaTensor& filter,
const MetaTensor& filter_max, const MetaTensor& filter_max,
const MetaTensor& bias, const MetaTensor& bias,
...@@ -50,9 +50,9 @@ void Conv2dXPUInferMeta(const MetaTensor& input, ...@@ -50,9 +50,9 @@ void Conv2dXPUInferMeta(const MetaTensor& input,
bool has_branch, bool has_branch,
int act_type, int act_type,
float act_param, float act_param,
MetaTensor* output, MetaTensor* out,
MetaTensor* output_max) { MetaTensor* out_max) {
auto in_dims = input.dims(); auto in_dims = x.dims();
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
// do some checks // do some checks
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -157,8 +157,8 @@ void Conv2dXPUInferMeta(const MetaTensor& input, ...@@ -157,8 +157,8 @@ void Conv2dXPUInferMeta(const MetaTensor& input,
strides[i])); strides[i]));
} }
// set output and output max dims // set output and output max dims
output->set_dims(DDim(out_shape.data(), out_shape.size())); out->set_dims(DDim(out_shape.data(), out_shape.size()));
output_max->set_dims(phi::make_ddim({4})); out_max->set_dims(phi::make_ddim({4}));
} }
void EmbeddingWithEltwiseAddXPUInferMeta( void EmbeddingWithEltwiseAddXPUInferMeta(
......
...@@ -22,8 +22,8 @@ namespace phi { ...@@ -22,8 +22,8 @@ namespace phi {
// Common InferMeta Functions for fusion operators. // Common InferMeta Functions for fusion operators.
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order. // NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void Conv2dXPUInferMeta(const MetaTensor& input, void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& input_max, const MetaTensor& x_max,
const MetaTensor& filter, const MetaTensor& filter,
const MetaTensor& filter_max, const MetaTensor& filter_max,
const MetaTensor& bias, const MetaTensor& bias,
...@@ -37,8 +37,8 @@ void Conv2dXPUInferMeta(const MetaTensor& input, ...@@ -37,8 +37,8 @@ void Conv2dXPUInferMeta(const MetaTensor& input,
bool has_branch, bool has_branch,
int act_type, int act_type,
float act_param, float act_param,
MetaTensor* output, MetaTensor* out,
MetaTensor* output_max); MetaTensor* out_max);
void EmbeddingWithEltwiseAddXPUInferMeta( void EmbeddingWithEltwiseAddXPUInferMeta(
const std::vector<const MetaTensor*>& ids, const std::vector<const MetaTensor*>& ids,
......
...@@ -21,8 +21,8 @@ namespace fusion { ...@@ -21,8 +21,8 @@ namespace fusion {
template <typename T, typename Context> template <typename T, typename Context>
void Conv2dXPUKernel(const Context& ctx, void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& input, const DenseTensor& x,
const paddle::optional<DenseTensor>& input_max, const paddle::optional<DenseTensor>& x_max,
const DenseTensor& filter, const DenseTensor& filter,
const DenseTensor& filter_max, const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias, const paddle::optional<DenseTensor>& bias,
...@@ -36,10 +36,10 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -36,10 +36,10 @@ void Conv2dXPUKernel(const Context& ctx,
bool has_branch, bool has_branch,
int act_type, int act_type,
float act_param, float act_param,
DenseTensor* output, DenseTensor* out,
DenseTensor* output_max) { DenseTensor* out_max) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
auto input_dims = input.dims(); auto input_dims = x.dims();
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
// update paddings and dilations according to padding_algorithm // update paddings and dilations according to padding_algorithm
std::vector<int> paddings_vec = paddings; std::vector<int> paddings_vec = paddings;
...@@ -62,17 +62,16 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -62,17 +62,16 @@ void Conv2dXPUKernel(const Context& ctx,
int win_h = static_cast<int>(filter_dims[2]); int win_h = static_cast<int>(filter_dims[2]);
int win_w = static_cast<int>(filter_dims[3]); int win_w = static_cast<int>(filter_dims[3]);
auto* input_data = reinterpret_cast<const XPUType*>(input.data<T>()); auto* input_data = reinterpret_cast<const XPUType*>(x.data<T>());
const float* input_max_data = input_max.get_ptr() == nullptr const float* input_max_data =
? nullptr x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
: input_max.get_ptr()->data<float>();
auto* branch_data = auto* branch_data =
branch.get_ptr() == nullptr branch.get_ptr() == nullptr
? nullptr ? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>()); : reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
const float* bias_data = const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>(); bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(output)); auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type)); xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) { if (act_type == xpu::Activation_t::LEAKY_RELU) {
...@@ -98,13 +97,13 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -98,13 +97,13 @@ void Conv2dXPUKernel(const Context& ctx,
/* int64_t groups */ groups, /* int64_t groups */ groups,
/* const float* in_maxptr */ input_max_data, /* const float* in_maxptr */ input_max_data,
/* const float* filter_maxptr */ filter_max.data<float>(), /* const float* filter_maxptr */ filter_max.data<float>(),
/* float* out_maxptr */ ctx.template Alloc<float>(output_max), /* float* out_maxptr */ ctx.template Alloc<float>(out_max),
/* bool is_nchw */ true, /* bool is_nchw */ true,
/* const float* bias */ bias_data, /* const float* bias */ bias_data,
/* const TY* branch */ branch_data, /* const TY* branch */ branch_data,
/* const baidu::xpu::api::Activation_t& act */ act, /* const baidu::xpu::api::Activation_t& act */ act,
/* const float* branch_maxptr */ nullptr); /* const float* branch_maxptr */ nullptr,
// /* const float* scale */ nullptr); /* const float* scale */ nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册