未验证 提交 8eaff62d 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add function FindInputNameByVarName (#46759)

* Add methods that find input or output name by var name

* kind of bugfix - initialize variables

* ci fix

* review fixed
上级 0ce5554c
......@@ -873,7 +873,7 @@ void CPUQuantizePass::QuantizeElementwise(
auto x_name = elementwise_op->Op()->Input("X");
auto y_name = elementwise_op->Op()->Input("Y");
Node *elementwise_x, *elementwise_y;
Node *elementwise_x{nullptr}, *elementwise_y{nullptr};
for (auto& input : elementwise_op->inputs) {
if (input->Name() == x_name[0]) elementwise_x = input;
......
......@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -262,10 +263,8 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern);
if (requant_in->outputs.size() == 1) {
std::string any_op_output_name;
for (auto name : any_op->Op()->OutputNames())
for (auto output_name : any_op->Op()->Output(name))
if (output_name == requant_in->Name()) any_op_output_name = name;
std::string any_op_output_name =
FindOutputNameByVarName(any_op->Op(), requant_in->Name());
PADDLE_ENFORCE_NE(
any_op_output_name.empty(),
......@@ -308,10 +307,8 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
if (requant_out->outputs.size() == 1) {
std::string any_op_input_name;
for (auto name : any_op->Op()->InputNames())
for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name;
std::string any_op_input_name =
FindInputNameByVarName(any_op->Op(), requant_out->Name());
PADDLE_ENFORCE_NE(any_op_input_name.empty(),
true,
......@@ -374,10 +371,8 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
return;
}
// Find the name of the output linking any_op to dequant_in
std::string output_name;
for (auto name : any_op->Op()->OutputNames())
for (auto out_name : any_op->Op()->Output(name))
if (out_name == dequant_in->Name()) output_name = name;
std::string output_name =
FindOutputNameByVarName(any_op->Op(), dequant_in->Name());
if (output_name.empty()) return;
......@@ -430,17 +425,16 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
quant_op->Op()->GetAttrIfExists<float>("Shift") == shift) {
auto quant_out = quant_op->outputs[0];
auto last_op = quant_out->outputs[0];
auto last_op_op = last_op->Op();
std::string last_op_input_name;
for (auto name : last_op->Op()->InputNames())
for (auto input_name : last_op->Op()->Input(name))
if (input_name == quant_out->Name()) last_op_input_name = name;
std::string last_op_input_name =
FindInputNameByVarName(last_op_op, quant_out->Name());
PADDLE_ENFORCE_NE(
last_op_input_name.empty(),
true,
platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input.",
"should have quantize output as input.",
quant_out->Name()));
// update the next operator input,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -110,10 +111,8 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
"Scale(%f) of scale op should have positive value.",
scale_scale));
std::string matmul_op_input_name;
for (auto name : matmul_op->Op()->InputNames())
for (auto input_name : matmul_op->Op()->Input(name))
if (input_name == scale_out->Name()) matmul_op_input_name = name;
std::string matmul_op_input_name =
FindInputNameByVarName(matmul_op->Op(), scale_out->Name());
PADDLE_ENFORCE_NE(
matmul_op_input_name.empty(),
......
......@@ -632,4 +632,22 @@ bool constexpr is_int8() {
}
} // namespace platform
inline std::string FindInputNameByVarName(framework::OpDesc* op,
const std::string& searched_name) {
std::string ret;
for (const auto& name : op->InputNames())
for (const auto& input_name : op->Input(name))
if (input_name == searched_name) ret = name;
return ret;
}
inline std::string FindOutputNameByVarName(framework::OpDesc* op,
const std::string& searched_name) {
std::string ret;
for (const auto& name : op->OutputNames())
for (const auto& output_name : op->Output(name))
if (output_name == searched_name) ret = name;
return ret;
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册