未验证 提交 8eaff62d 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add function FindInputNameByVarName (#46759)

* Add methods that find input or output name by var name

* kind of bugfix - initialize variables

* ci fix

* review fixes applied
上级 0ce5554c
...@@ -873,7 +873,7 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -873,7 +873,7 @@ void CPUQuantizePass::QuantizeElementwise(
auto x_name = elementwise_op->Op()->Input("X"); auto x_name = elementwise_op->Op()->Input("X");
auto y_name = elementwise_op->Op()->Input("Y"); auto y_name = elementwise_op->Op()->Input("Y");
Node *elementwise_x, *elementwise_y; Node *elementwise_x{nullptr}, *elementwise_y{nullptr};
for (auto& input : elementwise_op->inputs) { for (auto& input : elementwise_op->inputs) {
if (input->Name() == x_name[0]) elementwise_x = input; if (input->Name() == x_name[0]) elementwise_x = input;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -262,10 +263,8 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { ...@@ -262,10 +263,8 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern); GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern);
if (requant_in->outputs.size() == 1) { if (requant_in->outputs.size() == 1) {
std::string any_op_output_name; std::string any_op_output_name =
for (auto name : any_op->Op()->OutputNames()) FindOutputNameByVarName(any_op->Op(), requant_in->Name());
for (auto output_name : any_op->Op()->Output(name))
if (output_name == requant_in->Name()) any_op_output_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
any_op_output_name.empty(), any_op_output_name.empty(),
...@@ -308,10 +307,8 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -308,10 +307,8 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern); GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
if (requant_out->outputs.size() == 1) { if (requant_out->outputs.size() == 1) {
std::string any_op_input_name; std::string any_op_input_name =
for (auto name : any_op->Op()->InputNames()) FindInputNameByVarName(any_op->Op(), requant_out->Name());
for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE(any_op_input_name.empty(), PADDLE_ENFORCE_NE(any_op_input_name.empty(),
true, true,
...@@ -374,10 +371,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { ...@@ -374,10 +371,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
return; return;
} }
// Find the name of the output linking any_op to dequant_in // Find the name of the output linking any_op to dequant_in
std::string output_name; std::string output_name =
for (auto name : any_op->Op()->OutputNames()) FindOutputNameByVarName(any_op->Op(), dequant_in->Name());
for (auto out_name : any_op->Op()->Output(name))
if (out_name == dequant_in->Name()) output_name = name;
if (output_name.empty()) return; if (output_name.empty()) return;
...@@ -430,17 +425,16 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -430,17 +425,16 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
quant_op->Op()->GetAttrIfExists<float>("Shift") == shift) { quant_op->Op()->GetAttrIfExists<float>("Shift") == shift) {
auto quant_out = quant_op->outputs[0]; auto quant_out = quant_op->outputs[0];
auto last_op = quant_out->outputs[0]; auto last_op = quant_out->outputs[0];
auto last_op_op = last_op->Op();
std::string last_op_input_name; std::string last_op_input_name =
for (auto name : last_op->Op()->InputNames()) FindInputNameByVarName(last_op_op, quant_out->Name());
for (auto input_name : last_op->Op()->Input(name))
if (input_name == quant_out->Name()) last_op_input_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
last_op_input_name.empty(), last_op_input_name.empty(),
true, true,
platform::errors::NotFound("Operator after quantize operator(%s) " platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input.", "should have quantize output as input.",
quant_out->Name())); quant_out->Name()));
// update the next operator input, // update the next operator input,
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -110,10 +111,8 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -110,10 +111,8 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
"Scale(%f) of scale op should have positive value.", "Scale(%f) of scale op should have positive value.",
scale_scale)); scale_scale));
std::string matmul_op_input_name; std::string matmul_op_input_name =
for (auto name : matmul_op->Op()->InputNames()) FindInputNameByVarName(matmul_op->Op(), scale_out->Name());
for (auto input_name : matmul_op->Op()->Input(name))
if (input_name == scale_out->Name()) matmul_op_input_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
matmul_op_input_name.empty(), matmul_op_input_name.empty(),
......
...@@ -632,4 +632,22 @@ bool constexpr is_int8() { ...@@ -632,4 +632,22 @@ bool constexpr is_int8() {
} }
} // namespace platform } // namespace platform
// Returns the input parameter (slot) name of `op` under which the variable
// `searched_name` is registered, or an empty string when no input slot
// refers to that variable. If the variable appears under several slots,
// the last matching slot name is returned (full scan, no early exit).
inline std::string FindInputNameByVarName(framework::OpDesc* op,
                                          const std::string& searched_name) {
  std::string matched_slot;
  for (const auto& slot : op->InputNames()) {
    for (const auto& var : op->Input(slot)) {
      if (var == searched_name) {
        matched_slot = slot;
      }
    }
  }
  return matched_slot;
}
// Returns the output parameter (slot) name of `op` under which the variable
// `searched_name` is registered, or an empty string when no output slot
// refers to that variable. If the variable appears under several slots,
// the last matching slot name is returned (full scan, no early exit).
inline std::string FindOutputNameByVarName(framework::OpDesc* op,
                                           const std::string& searched_name) {
  std::string matched_slot;
  for (const auto& slot : op->OutputNames()) {
    for (const auto& var : op->Output(slot)) {
      if (var == searched_name) {
        matched_slot = slot;
      }
    }
  }
  return matched_slot;
}
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册