未验证 提交 6c0982b9 编写于 作者: W wanghuancoder 提交者: GitHub

fix some errmsg report, in framework/ir/mkldnn (#25467)

* fix paddle/fluid/framework/ir/mkldnn/ error msg reoprt, test=develop

* modify error msg reoprt, about errortype, grammar, supplementary infor, test=develop

* modified some error descriptions, test=develop
上级 fce64662
...@@ -22,7 +22,8 @@ namespace framework { ...@@ -22,7 +22,8 @@ namespace framework {
namespace ir { namespace ir {
void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr."); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv_activation_mkldnn_fuse", graph); FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {activation, conv_out}); GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL, PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
"subgraph has to contain conv_input node."); platform::errors::InvalidArgument(
"Subgraph has to contain conv input node."));
IR_NODE_LINK_TO(conv, activation_out); IR_NODE_LINK_TO(conv, activation_out);
found_conv_activation_count++; found_conv_activation_count++;
}; };
......
...@@ -26,7 +26,11 @@ namespace ir { ...@@ -26,7 +26,11 @@ namespace ir {
template <typename BinaryOperation> template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) { BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims(),
platform::errors::InvalidArgument(
"Input two tensors must have same shape, but they are "
"different: %s, %s.",
vec_a.dims(), vec_b.dims()));
LoDTensor vec_y; LoDTensor vec_y;
vec_y.Resize(vec_a.dims()); vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>(); const float* a = vec_a.data<float>();
...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, ...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
} }
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
// elementwise_add op // elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input)); PADDLE_ENFORCE_NE(
subgraph.count(conv_input), 0,
platform::errors::NotFound("Detector did not find conv input."));
// check if fuse can be done and if MKL-DNN should be used // check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
if (has_bias && conv->Op()->Input("Bias").size() > 0) { if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias"); auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias // add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1); PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1,
platform::errors::NotFound("Can not find var Bias."));
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>(); auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims()); PADDLE_ENFORCE_EQ(
conv_bias_tensor->dims(), eltwise_bias_tensor->dims(),
platform::errors::InvalidArgument(
"Conv bias tensor and eltwise bias tensor "
"must have same shape, but they are different: %s, %s.",
conv_bias_tensor->dims(), eltwise_bias_tensor->dims()));
*conv_bias_tensor = tensor_apply_eltwise( *conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>()); *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
......
...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs( ...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs(
for (auto node : concat_inputs) { for (auto node : concat_inputs) {
auto prev_op_node = node->inputs; auto prev_op_node = node->inputs;
PADDLE_ENFORCE_EQ(prev_op_node.size(), 1); PADDLE_ENFORCE_EQ(prev_op_node.size(), 1,
platform::errors::InvalidArgument(
"Node(%s) input size(%d) must be 1.", node->Name(),
prev_op_node.size()));
auto* conv_op = prev_op_node[0]; auto* conv_op = prev_op_node[0];
if (conv_op->Op()->Type() != "conv2d") return; if (conv_op->Op()->Type() != "conv2d") return;
...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU( ...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU(
} }
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const { void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
std::unordered_map<const Node*, int> concat_with_convs_counter; std::unordered_map<const Node*, int> concat_with_convs_counter;
......
...@@ -68,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, ...@@ -68,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
auto inputs = op->Op()->InputNames(); auto inputs = op->Op()->InputNames();
bool name_found = bool name_found =
std::find(inputs.begin(), inputs.end(), input_name) != inputs.end(); std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(name_found, true,
name_found, true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("%s isn't the input of the %s operator", "Var(%s) isn't the input of the %s operator.",
input_name, op->Op()->Type())); input_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -110,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, ...@@ -110,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
std::string scale_attr_name) const { std::string scale_attr_name) const {
auto inputs = op->inputs; auto inputs = op->inputs;
auto output = op->outputs[0]; auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(), 1); PADDLE_ENFORCE_GE(inputs.size(), 1,
PADDLE_ENFORCE_EQ(op->outputs.size(), 1); platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.",
op->Name(), inputs.size()));
PADDLE_ENFORCE_EQ(op->outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal to 1.", op->Name(),
op->outputs.size()));
// create a quantize op desc prototype // create a quantize op desc prototype
OpDesc q_desc; OpDesc q_desc;
...@@ -159,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, ...@@ -159,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::find(outputs.begin(), outputs.end(), output_name) != outputs.end(); std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
PADDLE_ENFORCE_EQ(name_found, true, PADDLE_ENFORCE_EQ(name_found, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"%s isn't the output of the %s operator", output_name, "Var(%s) isn't the output of the %s operator.",
op->Op()->Type())); output_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -682,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -682,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
bool is_x_unsigned{false}, is_y_unsigned{false}; bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned); auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned); auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(is_x_unsigned, is_y_unsigned,
is_x_unsigned, is_y_unsigned, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Matmul inputs should have the same "
"Matmul inputs should have the same value of is_unsigned")); "attribute of signed/unsigned, but they "
"are different: x(%d), y(%d).",
is_x_unsigned, is_y_unsigned));
QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned, QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned,
"Scale_x"); "Scale_x");
QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned, QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned,
...@@ -785,10 +793,12 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const { ...@@ -785,10 +793,12 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope()); PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
"Scope cannot be nullptr."));
QuantizeConv(graph, false /* with_residual_data */); QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */);
......
...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( ...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale")); BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
platform::errors::NotFound("The dequant output node is not found")); platform::errors::NotFound("The dequant output node is not found."));
// check if dequantize op should be kept or removed, decrease the counter // check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { ...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true, any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator " platform::errors::NotFound("Operator before requantize operator(%s) "
"should have requantize input as output")); "should have requantize input as output.",
requant_in->Name()));
float requant_scale_out = float requant_scale_out =
BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out")); BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out"));
...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
for (auto input_name : any_op->Op()->Input(name)) for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name; if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(any_op_input_name.empty(), true,
any_op_input_name.empty(), true, platform::errors::NotFound(
platform::errors::NotFound("The operator after requantize operator " "The operator after requantize operator(%s) "
"should have requantize output as input")); "should have requantize output as input.",
requant_out->Name()));
float requant_scale_in = float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in")); boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
if (any_op->Op()->Type() == "matmul") if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y"; scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"), PADDLE_ENFORCE_EQ(
any_op->Op()->GetAttrIfExists<float>(scale_name), requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
platform::errors::InvalidArgument( any_op->Op()->GetAttrIfExists<float>(scale_name),
"The operator after requantize should have input " platform::errors::InvalidArgument(
"scale equal to requantize output scale")); "The operator after requantize should have input "
"scale(%f) equal to requantize output scale(%f).",
any_op->Op()->GetAttrIfExists<float>(scale_name),
requant_op->Op()->GetAttrIfExists<float>("Scale_out")));
any_op->Op()->SetAttr(scale_name, requant_scale_in); any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name, any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()})); std::vector<std::string>({requant_in->Name()}));
...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
auto* first_quant_out = first_quant_op->outputs[0]; auto* first_quant_out = first_quant_op->outputs[0];
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale"); float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(scale, 0,
"Quantize scale should not be equal 0")); platform::errors::InvalidArgument(
"Quantize scale(%f) should not be equal 0.", scale));
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) { for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
auto quant_op = prev_out->outputs[iter]; auto quant_op = prev_out->outputs[iter];
...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
last_op_input_name.empty(), true, last_op_input_name.empty(), true,
platform::errors::NotFound("Operator after quantize operator " platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input")); "should has quantize output as input.",
quant_out->Name()));
last_op->Op()->SetInput( last_op->Op()->SetInput(
last_op_input_name, last_op_input_name,
std::vector<std::string>({first_quant_out->Name()})); std::vector<std::string>({first_quant_out->Name()}));
...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
PADDLE_ENFORCE_GT(dequant_scale, 0.0f, PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Dequantize scale should have positive value")); "Dequantize scale(%f) should have positive value.",
dequant_scale));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale); dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
dequant_op->Op()->SetOutput( dequant_op->Op()->SetOutput(
...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, graph,
platform::errors::NotFound( platform::errors::InvalidArgument(
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null")); "The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
FusePassBase::Init("cpu_quantize_squash_pass", graph); FusePassBase::Init("cpu_quantize_squash_pass", graph);
std::unordered_map<const Node*, int> nodes_keep_counter; std::unordered_map<const Node*, int> nodes_keep_counter;
......
...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
PADDLE_ENFORCE_EQ(inputs.size(), 2UL, PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The fc inputs should contain input and weights, but " "The fc inputs should contain input and weights, but "
"now the size of inputs is %d", "now the size of inputs is %d.",
inputs.size())); inputs.size()));
op->SetInput("W", {inputs[1]}); op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
......
...@@ -19,14 +19,17 @@ namespace paddle { ...@@ -19,14 +19,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
#define GET_NODE(id, pattern) \ #define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ PADDLE_ENFORCE_NE(subgraph.count(pattern.RetrieveNode(#id)), 0, \
"pattern has no Node called %s", #id); \ platform::errors::InvalidArgument( \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ "Pattern has no Node called %s.", #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id));
void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph); FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
......
...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) { if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha"); auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale"); auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
PADDLE_ENFORCE_GT(matmul_alpha, 0.0f, PADDLE_ENFORCE_GT(
platform::errors::InvalidArgument( matmul_alpha, 0.0f,
"Alpha of matmul op should have positive value")); platform::errors::InvalidArgument(
"Alpha(%f) of matmul op should have positive value.",
matmul_alpha));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
std::string matmul_op_input_name; std::string matmul_op_input_name;
for (auto name : matmul_op->Op()->InputNames()) for (auto name : matmul_op->Op()->InputNames())
...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
matmul_op_input_name.empty(), true, matmul_op_input_name.empty(), true,
platform::errors::NotFound("Operator after scale operator " platform::errors::NotFound("Operator after scale operator(%s) "
"should have scale output as input")); "should have scale output as input.",
scale_out->Name()));
matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale); matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale);
matmul_op->Op()->SetInput(matmul_op_input_name, matmul_op->Op()->SetInput(matmul_op_input_name,
std::vector<std::string>({scale_in->Name()})); std::vector<std::string>({scale_in->Name()}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册