未验证 提交 4a1b7fec 编写于 作者: Z Zhen Wang 提交者: GitHub

Add setting Scope function for the graph class (#17417)

* add set_not_owned function for graph

* add scope set. test=develop

* add scope_ptr enforce not null before setting.test=develop
上级 58d5c61a
...@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, ...@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param) { void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters // Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters. // Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>(); scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>(); scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>(); scope.Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>(); scope.Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>(); scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>(); scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>(); scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>(); scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \ #define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \ auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \ auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \ auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \ CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \ VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
...@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) { ...@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
GATE_W(c); GATE_W(c);
#undef GATE_W #undef GATE_W
auto* attention_fc_w = scope->FindVar("attention_fc.w_0"); auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0"); auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0"); auto* attention_output_w = scope.FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0"); auto* attention_output_b = scope.FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w, CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b); attention_output_b);
auto* lstm_weight = scope->Var(param.LSTMWeight); auto* lstm_weight = scope.Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>(); auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias); auto* lstm_bias = scope.Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>(); auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias // reshape attention_bias
auto* attention_bias_t = auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>(); scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1); PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]})); attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t = auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>(); scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize( attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]})); make_ddim({1, attention_scalar_bias_t->dims()[0]}));
......
...@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \ #define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \ const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \ op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>() scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
......
...@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
PADDLE_ENFORCE(scope);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias // Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name()); auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor = auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>(); fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var); PADDLE_ENFORCE(fusion_bias_var);
auto* gru_bias_var = scope->FindVar(bias->Name()); auto* gru_bias_var = scope.FindVar(bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias->Name()); auto* fc_bias_var = scope.FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var); PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var); PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
...@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef GET_NODE #undef GET_NODE
#define NEW_IMTERMEDIATE_OUT(key) \ #define NEW_IMTERMEDIATE_OUT(key) \
scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>() scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
NEW_IMTERMEDIATE_OUT(ReorderedH0); NEW_IMTERMEDIATE_OUT(ReorderedH0);
NEW_IMTERMEDIATE_OUT(XX); NEW_IMTERMEDIATE_OUT(XX);
NEW_IMTERMEDIATE_OUT(BatchedInput); NEW_IMTERMEDIATE_OUT(BatchedInput);
......
...@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \ #define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \ const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \ op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>() scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include <unordered_map>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const { ...@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
Scope* FusePassBase::param_scope() const { Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr); auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope;
} }
void FusePassBase::AddStatis(int count_of_fused) const { void FusePassBase::AddStatis(int count_of_fused) const {
...@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1, ...@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
#else #else
return FUSE_NATIVE; return FUSE_NATIVE;
#endif #endif
}; }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -97,7 +97,7 @@ void MainTest(bool convWithExistingBias) { ...@@ -97,7 +97,7 @@ void MainTest(bool convWithExistingBias) {
InitTensorHolder(&scope, place, "conv_bias"); InitTensorHolder(&scope, place, "conv_bias");
InitTensorHolder(&scope, place, "eltwise_bias"); InitTensorHolder(&scope, place, "eltwise_bias");
} }
graph->Set(kParamScopeAttr, new framework::Scope*(&scope)); graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass"); auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
......
...@@ -132,7 +132,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, ...@@ -132,7 +132,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
(*scales)[v] = std::make_pair(false, std::move(tensor)); (*scales)[v] = std::make_pair(false, std::move(tensor));
} }
graph->Set(kParamScopeAttr, new framework::Scope*(&scope)); graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales); pass->Set("quant_var_scales", scales);
......
...@@ -119,7 +119,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) { ...@@ -119,7 +119,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
InitTensorHolder(&scope, place, v.c_str()); InitTensorHolder(&scope, place, v.c_str());
} }
graph->Set(kParamScopeAttr, new framework::Scope*(&scope)); graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass"); auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
......
...@@ -43,11 +43,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { ...@@ -43,11 +43,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart")); op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride")); op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat"); const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat}); op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()}); op_desc.SetOutput("Out", {relu_out->Name()});
scope->Var(ColMat)->GetMutable<LoDTensor>(); scope.Var(ColMat)->GetMutable<LoDTensor>();
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op); IR_NODE_LINK_TO(input, op);
......
...@@ -38,9 +38,9 @@ IRPassManager::IRPassManager(Argument *argument) { ...@@ -38,9 +38,9 @@ IRPassManager::IRPassManager(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, main_program); ARGUMENT_CHECK_FIELD(argument, main_program);
graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program())); graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
if (argument->Has("scope")) { if (argument->Has("scope")) {
graph_->Set(framework::ir::kParamScopeAttr, auto *scope_ptr = argument->scope_ptr();
new framework::Scope *( PADDLE_ENFORCE(scope_ptr);
const_cast<framework::Scope *>(&argument->scope()))); graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
} }
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes); ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include <paddle/fluid/framework/ir/fuse_pass_base.h> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -56,8 +57,9 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { ...@@ -56,8 +57,9 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program())); auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
argument->SetMainGraph(graph.release()); argument->SetMainGraph(graph.release());
argument->main_graph().Set(framework::ir::kParamScopeAttr, auto *scope_ptr = argument->scope_ptr();
new framework::Scope *(argument->scope_ptr())); PADDLE_ENFORCE(scope_ptr);
argument->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
} }
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
......
...@@ -353,8 +353,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { ...@@ -353,8 +353,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
arg.SetMainProgramNotOwned(predictor_.inference_program_.get()); arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program())); auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
arg.SetMainGraph(graph.release()); arg.SetMainGraph(graph.release());
arg.main_graph().Set(framework::ir::kParamScopeAttr, auto* scope_ptr = arg.scope_ptr();
new framework::Scope*(arg.scope_ptr())); PADDLE_ENFORCE(scope_ptr);
arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
auto* builder = predictor_.config_.pass_builder(); auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({ builder->SetPasses({
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
...@@ -37,6 +38,7 @@ using paddle::framework::ir::TopologySortOperations; ...@@ -37,6 +38,7 @@ using paddle::framework::ir::TopologySortOperations;
using paddle::framework::ir::BuildOperationAdjList; using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::OpDesc; using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc; using paddle::framework::ProgramDesc;
using paddle::framework::Scope;
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
using pybind11::return_value_policy; using pybind11::return_value_policy;
...@@ -57,12 +59,15 @@ void BindGraph(py::module *m) { ...@@ -57,12 +59,15 @@ void BindGraph(py::module *m) {
.def(py::init<const ProgramDesc &>()) .def(py::init<const ProgramDesc &>())
.def("clone", &Graph::Clone) .def("clone", &Graph::Clone)
.def("has", &Graph::Has) .def("has", &Graph::Has)
.def("get_bool", &Graph::Get<bool>)
.def("get_int", &Graph::Get<int>) .def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>) .def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>) .def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>, .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference) return_value_policy::reference)
.def("set", [](Graph &self, const std::string &attr_name,
bool attr) { return self.Set(attr_name, new bool(attr)); })
.def("set", [](Graph &self, const std::string &attr_name, .def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); }) int attr) { return self.Set(attr_name, new int(attr)); })
.def("set", .def("set",
...@@ -90,6 +95,10 @@ void BindGraph(py::module *m) { ...@@ -90,6 +95,10 @@ void BindGraph(py::module *m) {
return self.Set(attr_name, return self.Set(attr_name,
new std::unordered_set<std::string>(attr)); new std::unordered_set<std::string>(attr));
}) })
.def("set_not_owned",
[](Graph &self, const std::string &attr_name, Scope &attr) {
self.SetNotOwned<Scope>(attr_name, &attr);
})
.def("erase", &Graph::Erase) .def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference) .def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node", .def("create_var_node",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册