Unverified commit 4a1b7fec, authored by Zhen Wang and committed by GitHub

Add setting Scope function for the graph class (#17417)

* add set_not_owned function for graph

* add scope set. test=develop

* add scope_ptr enforce not null before setting. test=develop
Parent 58d5c61a
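Summary of the change: instead of storing the parameter scope as a graph-owned attribute (`graph->Set(kParamScopeAttr, new framework::Scope*(&scope))`, read back with `graph->Get<Scope*>(...)`), the graph now only borrows the scope through the new `SetNotOwned` API and reads it back as a reference with `graph->Get<Scope>(...)`. Below is a minimal, self-contained sketch of the owned vs. not-owned attribute idea; `MiniGraph`, the `Scope` stand-in, and the attribute key are illustrative only and are not PaddlePaddle's actual implementation.

```cpp
// Minimal sketch (not Paddle's real Graph): an attribute map that either
// owns its values (Set) or merely borrows them (SetNotOwned).
#include <cassert>
#include <functional>
#include <map>
#include <string>

struct Scope {};  // illustrative stand-in for framework::Scope

class MiniGraph {
 public:
  template <typename T>
  void Set(const std::string &name, T *attr) {  // graph takes ownership
    attrs_[name] = attr;
    deleters_[name] = [attr] { delete attr; };
  }

  template <typename T>
  void SetNotOwned(const std::string &name, T *attr) {  // caller keeps ownership
    attrs_[name] = attr;
  }

  template <typename T>
  T &Get(const std::string &name) const {  // returns a reference, as in the diff
    return *static_cast<T *>(attrs_.at(name));
  }

  ~MiniGraph() {
    for (auto &kv : deleters_) kv.second();  // free only the owned attributes
  }

 private:
  std::map<std::string, void *> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

int main() {
  Scope scope;      // lives outside the graph (e.g. owned by the predictor)
  MiniGraph graph;
  // Old pattern: graph.Set("param_scope", new Scope *(&scope));  // extra heap cell
  graph.SetNotOwned<Scope>("param_scope", &scope);  // new pattern: just borrow
  Scope &s = graph.Get<Scope>("param_scope");
  assert(&s == &scope);
  return 0;
}
```

The real `Graph::Set` appears to register a type-erased deleter in the same spirit, which is why the old code had to heap-allocate a `Scope*` cell just to hand the graph something it could safely delete; `SetNotOwned` removes that indirection.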
......@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
- auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+ auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters.
- scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
- scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
- scope->Var(param.Hidden)->GetMutable<LoDTensor>();
- scope->Var(param.Cell)->GetMutable<LoDTensor>();
- scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
- scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
- scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
- scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+ scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
+ scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
+ scope.Var(param.Hidden)->GetMutable<LoDTensor>();
+ scope.Var(param.Cell)->GetMutable<LoDTensor>();
+ scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
+ scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
+ scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
+ scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
- auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
- auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
- auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
+ auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
+ auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1"); \
+ auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
......@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
GATE_W(c);
#undef GATE_W
- auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
- auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
- auto* attention_output_w = scope->FindVar("attention_output.w_0");
- auto* attention_output_b = scope->FindVar("attention_output.b_0");
+ auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
+ auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
+ auto* attention_output_w = scope.FindVar("attention_output.w_0");
+ auto* attention_output_b = scope.FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
- auto* lstm_weight = scope->Var(param.LSTMWeight);
+ auto* lstm_weight = scope.Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
- auto* lstm_bias = scope->Var(param.LSTMBias);
+ auto* lstm_bias = scope.Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
-     scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
+     scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
-     scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
+     scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
......
......@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
- auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+ auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
- scope->Var(x)->GetMutable<LoDTensor>()
+ scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......
......@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
- auto* scope = graph->Get<Scope*>(kParamScopeAttr);
- PADDLE_ENFORCE(scope);
+ auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias
- auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
+ auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var);
- auto* gru_bias_var = scope->FindVar(bias->Name());
- auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+ auto* gru_bias_var = scope.FindVar(bias->Name());
+ auto* fc_bias_var = scope.FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
......@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef GET_NODE
#define NEW_IMTERMEDIATE_OUT(key) \
- scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
+ scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
NEW_IMTERMEDIATE_OUT(ReorderedH0);
NEW_IMTERMEDIATE_OUT(XX);
NEW_IMTERMEDIATE_OUT(BatchedInput);
......
......@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
- auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+ auto& scope = graph->Get<Scope>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
- scope->Var(x)->GetMutable<LoDTensor>()
+ scope.Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+ #include <unordered_map>
namespace paddle {
namespace framework {
......@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
- return graph_->Get<framework::Scope*>(kParamScopeAttr);
+ auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
+ return &scope;
}
void FusePassBase::AddStatis(int count_of_fused) const {
......@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
#else
return FUSE_NATIVE;
#endif
- };
+ }
} // namespace ir
} // namespace framework
......
......@@ -97,7 +97,7 @@ void MainTest(bool convWithExistingBias) {
InitTensorHolder(&scope, place, "conv_bias");
InitTensorHolder(&scope, place, "eltwise_bias");
}
- graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+ graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
......
......@@ -132,7 +132,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
- graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+ graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
......
......@@ -119,7 +119,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
InitTensorHolder(&scope, place, v.c_str());
}
- graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+ graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
......
......@@ -43,11 +43,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
- auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+ auto& scope = graph->Get<Scope>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()});
- scope->Var(ColMat)->GetMutable<LoDTensor>();
+ scope.Var(ColMat)->GetMutable<LoDTensor>();
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
......
......@@ -38,9 +38,9 @@ IRPassManager::IRPassManager(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, main_program);
graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
if (argument->Has("scope")) {
- graph_->Set(framework::ir::kParamScopeAttr,
-             new framework::Scope *(
-                 const_cast<framework::Scope *>(&argument->scope())));
+ auto *scope_ptr = argument->scope_ptr();
+ PADDLE_ENFORCE(scope_ptr);
+ graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
......
......@@ -13,9 +13,10 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <memory>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -56,8 +57,9 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
argument->SetMainGraph(graph.release());
- argument->main_graph().Set(framework::ir::kParamScopeAttr,
-                            new framework::Scope *(argument->scope_ptr()));
+ auto *scope_ptr = argument->scope_ptr();
+ PADDLE_ENFORCE(scope_ptr);
+ argument->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
......
......@@ -353,8 +353,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
arg.SetMainGraph(graph.release());
- arg.main_graph().Set(framework::ir::kParamScopeAttr,
-                      new framework::Scope*(arg.scope_ptr()));
+ auto* scope_ptr = arg.scope_ptr();
+ PADDLE_ENFORCE(scope_ptr);
+ arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "pybind11/stl.h"
......@@ -37,6 +38,7 @@ using paddle::framework::ir::TopologySortOperations;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
+ using paddle::framework::Scope;
using paddle::framework::VarDesc;
using pybind11::return_value_policy;
......@@ -57,12 +59,15 @@ void BindGraph(py::module *m) {
.def(py::init<const ProgramDesc &>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has)
.def("get_bool", &Graph::Get<bool>)
.def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference)
.def("set", [](Graph &self, const std::string &attr_name,
bool attr) { return self.Set(attr_name, new bool(attr)); })
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
.def("set",
......@@ -90,6 +95,10 @@ void BindGraph(py::module *m) {
return self.Set(attr_name,
new std::unordered_set<std::string>(attr));
})
.def("set_not_owned",
[](Graph &self, const std::string &attr_name, Scope &attr) {
self.SetNotOwned<Scope>(attr_name, &attr);
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node",
......
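For reference, the attach/read pattern repeated across the C++ call sites above can be consolidated as follows. This is a hedged summary based only on the calls visible in this diff (`SetNotOwned`, `Get<Scope>`, `kParamScopeAttr`, the `PADDLE_ENFORCE` null check, and the header paths from the touched files); the function names and the variable name `example_var` are made up for illustration.

```cpp
#include "paddle/fluid/framework/ir/fuse_pass_base.h"  // kParamScopeAttr
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/enforce.h"

namespace fw = paddle::framework;

// Attach an externally owned scope to the graph (pass managers, tests).
void AttachParamScope(fw::ir::Graph *graph, fw::Scope *scope_ptr) {
  PADDLE_ENFORCE(scope_ptr);  // enforce not-null before setting
  graph->SetNotOwned(fw::ir::kParamScopeAttr, scope_ptr);
}

// Read it back inside a fuse pass: Get<Scope> now yields a reference.
void UseParamScope(fw::ir::Graph *graph) {
  PADDLE_ENFORCE(graph->Has(fw::ir::kParamScopeAttr));
  auto &scope = graph->Get<fw::Scope>(fw::ir::kParamScopeAttr);
  scope.Var("example_var");  // member access via '.' instead of '->'
}
```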