Unverified commit 8adaa0f0, authored by jianghaicheng, committed by GitHub

add popart_canonicalization p1 (#37964)

Parent commit: 89069af5
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/avg_shard_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
// Evenly distributes the graph's ops across the available IPUs by stamping
// each op with `ipu_index` / `ipu_stage` attributes. Only runs when the
// strategy's `need_avg_shard` flag is set.
void AvgShardPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter AvgShardPass::ApplyImpl";

  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();

  if (ipu_backend->GetIpuStrategy()->need_avg_shard) {
    VLOG(10) << "start AvgShardPass";

    auto nodes = ir::TopologySortOperations(*graph);
    auto num_ipus = ipu_backend->GetIpuStrategy()->num_ipus;

    // Ops per IPU. Clamp to at least 1: when there are fewer ops than IPUs,
    // nodes.size() / num_ipus is 0 and the modulo below would divide by zero.
    size_t shard_position = nodes.size() / num_ipus;
    if (shard_position == 0) {
      shard_position = 1;
    }

    int index_and_stage = -1;
    for (size_t i = 0; i < nodes.size(); i++) {
      // Advance to the next IPU every `shard_position` ops, but never past
      // the last available IPU index.
      if ((i % shard_position) == 0 &&
          index_and_stage < static_cast<int>(num_ipus) - 1) {
        index_and_stage++;
      }
      nodes[i]->Op()->SetAttr("ipu_index", index_and_stage);
      nodes[i]->Op()->SetAttr("ipu_stage", index_and_stage);
    }

    VLOG(10) << "end AvgShardPass";
  }

  VLOG(10) << "leave AvgShardPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register under the name referenced by InferenceProcessPass's pipeline.
REGISTER_PASS(avg_shard_pass, paddle::framework::ir::AvgShardPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that evenly shards the graph's ops across IPUs by tagging each op
// with `ipu_index` / `ipu_stage` attributes when avg-shard is enabled.
class AvgShardPass : public IPUPassBase {
protected:
// Assigns ipu_index/ipu_stage to every op when the strategy requests it.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
// Strips the graph down to its forward (inference) portion: backward and
// optimizer ops, control vars, optimizer-produced vars, and feed/fetch nodes
// are all removed, leaving only forward/loss ops and their variables.
void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter ForwardGraphExtractPass::ApplyImpl";

  // Bucket every op node by its op_role attribute.
  std::unordered_map<OpRole, std::unordered_set<ir::Node*>> all_ops{
      {OpRole::kForward, {}}, {OpRole::kBackward, {}},
      {OpRole::kOptimize, {}}, {OpRole::kRPC, {}},
      {OpRole::kDist, {}}, {OpRole::kLRSched, {}},
      {OpRole::kLoss, {}}, {OpRole::kNotSpecified, {}}};
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    auto op_role = BOOST_GET_MUTABLE(int, node->Op()->GetAttr("op_role"));
    if (op_role == static_cast<int>(OpRole::kForward)) {
      all_ops[OpRole::kForward].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kBackward)) {
      all_ops[OpRole::kBackward].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
      all_ops[OpRole::kOptimize].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kRPC)) {
      // kRPC / kDist / kLRSched ops are deliberately not collected; they do
      // not belong to the extracted forward graph.
    } else if (op_role == static_cast<int>(OpRole::kDist)) {
    } else if (op_role == static_cast<int>(OpRole::kLRSched)) {
    } else if (op_role == static_cast<int>(OpRole::kLoss)) {
      all_ops[OpRole::kLoss].insert(node);
    } else if (op_role == static_cast<int>(OpRole::kNotSpecified)) {
      LOG(WARNING) << "Op: " << node->Name() << " OpRole is NotSpecified ";
    }
  }

  std::unordered_set<ir::Node*> forward_vars;
  std::unordered_set<ir::Node*> backward_vars;
  std::unordered_set<ir::Node*> control_vars;
  // forward_vars: every input/output of a forward or loss op. Iterate the
  // buckets by role key rather than copying both unordered_sets into a
  // temporary std::array (the previous code paid two full set copies).
  for (auto role : {OpRole::kForward, OpRole::kLoss}) {
    for (auto* node : all_ops[role]) {
      for (auto* in_node : node->inputs) {
        forward_vars.insert(in_node);
      }
      for (auto* out_node : node->outputs) {
        forward_vars.insert(out_node);
      }
    }
  }
  // control_vars & backward_vars (vars produced by optimizer ops)
  for (auto* node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    if (node->IsCtrlVar()) {
      control_vars.insert(node);
    }
    for (auto* in_node : node->inputs) {
      if (all_ops[OpRole::kOptimize].count(in_node)) {
        backward_vars.insert(node);
      }
    }
  }

  // Collect all nodes to be removed: optimizer-produced vars, control vars,
  // backward ops, anything not reachable as a forward op/var, and feed/fetch.
  std::unordered_set<ir::Node*> rm_nodes;
  for (auto* node : graph->Nodes()) {
    if (backward_vars.count(node)) {
      rm_nodes.insert(node);
    } else if (control_vars.count(node)) {
      rm_nodes.insert(node);
    } else if (all_ops[OpRole::kBackward].count(node)) {
      rm_nodes.insert(node);
    } else if (all_ops[OpRole::kForward].count(node) == 0 &&
               all_ops[OpRole::kLoss].count(node) == 0 &&
               forward_vars.count(node) == 0) {
      rm_nodes.insert(node);
    } else if (node->Name() == "feed" || node->Name() == "fetch") {
      rm_nodes.insert(node);
    }
  }

  VLOG(10) << "Remove Node: ";
  for (auto* node : rm_nodes) {
    // Unlink the node from its neighbours before removal: erase the first
    // matching edge on each side (edges may appear once per connection).
    for (auto* node_in : node->inputs) {
      for (size_t i = 0; i < node_in->outputs.size(); ++i) {
        if (node_in->outputs[i] == node) {
          node_in->outputs.erase(node_in->outputs.begin() + i);
          break;
        }
      }
    }
    for (auto* node_out : node->outputs) {
      for (size_t i = 0; i < node_out->inputs.size(); ++i) {
        if (node_out->inputs[i] == node) {
          node_out->inputs.erase(node_out->inputs.begin() + i);
          break;
        }
      }
    }
    VLOG(10) << "\t" << node->Name();
    graph->RemoveNode(node);
  }

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave ForwardGraphExtractPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register under the name referenced by InferenceProcessPass's pipeline.
REGISTER_PASS(forward_graph_extract_pass,
paddle::framework::ir::ForwardGraphExtractPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that removes backward/optimizer ops, control vars and feed/fetch
// nodes, leaving only the forward (inference) subgraph.
class ForwardGraphExtractPass : public IPUPassBase {
protected:
// Deletes every node that is not part of the forward/loss computation.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/infer_shape_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace paddle {
namespace framework {
namespace ir {
// Materializes static shapes for all variables: fixes the feed vars' dynamic
// batch dimension to the configured batch size, then runs every op's runtime
// shape inference (in topological order) over a temporary scope and writes
// the inferred dims back into the var descs.
void InferShapePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferShapePass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  auto batch_size = ipu_backend->GetIpuStrategy()->batch_size;
  auto feed_list = Get<std::vector<std::string>>("feed_list");

  // Resolve dynamic batch dims of the feed variables and downcast INT64
  // feeds to INT32 (presumably an IPU-side dtype restriction — TODO confirm).
  for (auto node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    bool is_feed = std::find(feed_list.begin(), feed_list.end(),
                             node->Name()) != feed_list.end();
    if (is_feed) {
      auto input_shape = node->Var()->GetShape();
      // Guard against rank-0 feeds before touching dimension 0.
      if (!input_shape.empty() && input_shape[0] <= -1) {
        input_shape[0] = batch_size;
        node->Var()->SetShape(input_shape);
      }
      // int64->int32
      if (node->Var()->GetDataType() == proto::VarType::INT64) {
        node->Var()->SetDataType(proto::VarType::INT32);
      }
    }
  }

  // Temp scope holding one LoDTensor per variable, sized from the var desc,
  // so RuntimeInferShape has concrete inputs to work with.
  std::shared_ptr<paddle::framework::Scope> scope(
      new paddle::framework::Scope());
  for (auto node : graph->Nodes()) {
    if (!node->IsVar()) {
      continue;
    }
    auto var_desc = node->Var();
    auto* ptr = scope->Var(var_desc->Name());
    paddle::framework::InitializeVariable(ptr, var_desc->GetType());
    auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
    tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
  }

  // Infer shapes op by op in topological order; propagate each inferred dim
  // back onto the matching output var node's desc.
  auto nodes = ir::TopologySortOperations(*graph);
  for (auto node : nodes) {
    auto op_desc = node->Op();
    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
    paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), *scope);
    op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);
    for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
      // size_t avoids the signed/unsigned comparison of the original int loop.
      for (size_t i = 0; i < it->second.size(); i++) {
        auto output_name = op_desc->Output(it->first)[i];
        auto dim =
            it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
        auto new_shape = paddle::framework::vectorize(dim);
        for (auto output_node : node->outputs) {
          if (output_node->Name() == output_name) {
            output_node->Var()->SetShape(new_shape);
          }
        }
      }
    }
  }

  // release the temp scope
  scope.reset();

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave InferShapePass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Requires the caller to supply "feed_list" before Apply().
REGISTER_PASS(infer_shape_pass, paddle::framework::ir::InferShapePass)
.RequirePassAttr("feed_list");
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that resolves dynamic feed shapes to the configured batch size and
// runs per-op runtime shape inference to make all var shapes static.
class InferShapePass : public IPUPassBase {
protected:
// Writes inferred static shapes back into every variable's desc.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
// Re-attaches feed/fetch plumbing to the processed graph: creates one shared
// "feed" and "fetch" var node plus one feed/fetch op per entry in the
// feed_list / fetch_list pass attributes, wired by column index.
void InferencePostprocessPass::ApplyImpl(ir::Graph *graph) const {
  VLOG(10) << "enter InferencePostprocessPass::ApplyImpl";

  auto feed_list = Get<std::vector<std::string>>("feed_list");
  auto fetch_list = Get<std::vector<std::string>>("fetch_list");

  // Graph::CreateVarNode / CreateOpNode copy the desc into the node, so the
  // descs can live on the stack. The previous heap-allocated descs were
  // never freed (memory leak).
  paddle::framework::VarDesc feed_var("feed");
  feed_var.SetType(proto::VarType::FEED_MINIBATCH);
  auto *feed_var_node = graph->CreateVarNode(&feed_var);

  paddle::framework::VarDesc fetch_var("fetch");
  fetch_var.SetType(proto::VarType::FETCH_LIST);
  auto *fetch_var_node = graph->CreateVarNode(&fetch_var);

  // One feed op per feed variable; "col" records its slot in the feed list.
  for (size_t i = 0; i < feed_list.size(); i++) {
    for (auto node : graph->Nodes()) {
      if (node->Name() == feed_list[i]) {
        paddle::framework::OpDesc op;
        op.SetType("feed");
        op.SetInput("X", {"feed"});
        op.SetOutput("Out", {node->Name()});
        op.SetAttr("col", static_cast<int>(i));
        auto *op_node = graph->CreateOpNode(&op);
        node->inputs.push_back(op_node);
        op_node->outputs.push_back(node);
        feed_var_node->outputs.push_back(op_node);
        op_node->inputs.push_back(feed_var_node);
        break;
      }
    }
  }

  // One fetch op per fetch variable, mirrored wiring.
  for (size_t i = 0; i < fetch_list.size(); i++) {
    for (auto node : graph->Nodes()) {
      if (node->Name() == fetch_list[i]) {
        paddle::framework::OpDesc op;
        op.SetType("fetch");
        op.SetInput("X", {node->Name()});
        op.SetOutput("Out", {"fetch"});
        op.SetAttr("col", static_cast<int>(i));
        auto *op_node = graph->CreateOpNode(&op);
        node->outputs.push_back(op_node);
        op_node->inputs.push_back(node);
        fetch_var_node->inputs.push_back(op_node);
        op_node->outputs.push_back(fetch_var_node);
        break;
      }
    }
  }

  VLOG(10) << "leave InferencePostprocessPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Requires the caller to supply both name lists before Apply().
REGISTER_PASS(inference_postprocess_pass,
paddle::framework::ir::InferencePostprocessPass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that re-inserts feed/fetch var nodes and ops after the IPU
// compilation passes have run.
class InferencePostprocessPass : public IPUPassBase {
protected:
// Adds feed/fetch plumbing for the configured feed_list / fetch_list.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/inference_process_pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
// Orchestrates IPU inference compilation: configures an IpuBackend/strategy
// from graph attributes, recovers the ordered feed/fetch variable names from
// the feed/fetch ops, then runs the graph-transform and compile pass
// pipelines over the graph.
void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceProcessPass::ApplyImpl";

  // Get a new instance of ipu_backend
  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetNewInstance();

  // Set scope
  auto& scope = graph->Get<Scope>(kParamScopeAttr);
  ipu_backend->SetScope(scope);

  // Set ipu_strategy. NOTE(review): this is a function-local static, so the
  // same strategy object is reused across invocations of this pass — confirm
  // that is intentional.
  static std::shared_ptr<platform::ipu::IpuStrategy> ipu_strategy_instance_(
      new platform::ipu::IpuStrategy());
  ipu_strategy_instance_->is_training = false;
  auto num_ipus = graph->Get<int>("num_ipus");
  ipu_strategy_instance_->num_ipus = num_ipus;
  // Manual virtual-graph placement is only needed when sharding across
  // multiple IPUs.
  if (num_ipus > 1) {
    ipu_strategy_instance_->popart_options_.virtualGraphMode =
        platform::ipu::VirtualGraphMode::Manual;
  } else {
    ipu_strategy_instance_->popart_options_.virtualGraphMode =
        platform::ipu::VirtualGraphMode::Off;
  }

  auto enable_pipelining = graph->Get<bool>("enable_pipelining");
  ipu_strategy_instance_->popart_options_.enablePipelining = enable_pipelining;
  if (enable_pipelining) {
    auto batches_per_step = graph->Get<int>("batches_per_step");
    PADDLE_ENFORCE_GE(
        batches_per_step, num_ipus,
        platform::errors::InvalidArgument("Batched per step should be equal or "
                                          "greater than the number of IPUs"));
    ipu_strategy_instance_->batches_per_step = batches_per_step;
  }
  ipu_strategy_instance_->batch_size = graph->Get<int>("batch_size");
  ipu_strategy_instance_->need_avg_shard = graph->Get<bool>("need_avg_shard");

  ipu_backend->SetIpuStrategy(*(ipu_strategy_instance_.get()));

  // Get feed_list and fetch list: first size the lists by counting the
  // feed/fetch ops, then fill each slot via the op's "col" attribute so the
  // names end up in the original program order.
  std::vector<std::string> feed_list = {};
  std::vector<std::string> fetch_list = {};
  for (auto node : graph->Nodes()) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        feed_list.push_back("");
      }
    } else if (node->Name() == "fetch") {
      if (node->IsOp()) {
        fetch_list.push_back("");
      }
    }
  }
  for (auto node : graph->Nodes()) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        feed_list[BOOST_GET_CONST(int, node->Op()->GetAttr("col"))] =
            node->outputs[0]->Name();
      }
    } else if (node->Name() == "fetch") {
      if (node->IsOp()) {
        fetch_list[BOOST_GET_CONST(int, node->Op()->GetAttr("col"))] =
            node->inputs[0]->Name();
      }
    }
  }

  // Run passes. Iterate by const reference — the original copied each
  // std::string pass name per iteration.
  std::vector<std::string> graph_pass = {"forward_graph_extract_pass",
                                         "infer_shape_pass", "avg_shard_pass",
                                         "popart_canonicalization_pass"};
  std::vector<std::string> compile_pass = {
      "ipu_inplace_pass", "ipu_graph_builder_pass", "ipu_runtime_replacer_pass",
      "inference_postprocess_pass"};
  for (const auto& pass_name : graph_pass) {
    auto pass = PassRegistry::Instance().Get(pass_name);
    if (pass_name == "infer_shape_pass") {
      pass->Set("feed_list", new std::vector<std::string>(feed_list.begin(),
                                                          feed_list.end()));
    }
    pass->Apply(graph);
  }
  for (const auto& pass_name : compile_pass) {
    auto pass = PassRegistry::Instance().Get(pass_name);
    pass->Set("feed_list",
              new std::vector<std::string>(feed_list.begin(), feed_list.end()));
    pass->Set("fetch_list", new std::vector<std::string>(fetch_list.begin(),
                                                         fetch_list.end()));
    pass->Apply(graph);
  }

  VLOG(10) << "leave InferenceProcessPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Entry-point pass that drives the whole IPU inference pipeline.
REGISTER_PASS(inference_process_pass,
paddle::framework::ir::InferenceProcessPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// Top-level IR pass that configures the IPU backend/strategy from graph
// attributes and runs the IPU graph-transform and compile pass pipelines.
class InferenceProcessPass : public IPUPassBase {
protected:
// Configures the backend, recovers feed/fetch lists, and applies sub-passes.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace paddle {
namespace framework {
namespace ir {
// Hands the graph to the IPU backend for compilation, using the feed/fetch
// variable names supplied as pass attributes.
void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter IpuGraphBuilderPass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  // Input/output variable names configured by the caller of this pass.
  auto feed_list = Get<std::vector<std::string>>("feed_list");
  auto fetch_list = Get<std::vector<std::string>>("fetch_list");

  // Compile the graph through the shared IPU backend instance.
  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
      platform::ipu::IpuBackend::GetInstance();
  ipu_backend->Compile(graph, feed_list, fetch_list);

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave IpuGraphBuilderPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Requires the caller to supply both name lists before Apply().
REGISTER_PASS(ipu_graph_builder_pass,
paddle::framework::ir::IpuGraphBuilderPass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that invokes the IPU backend's Compile() on the prepared graph.
class IpuGraphBuilderPass : public IPUPassBase {
protected:
// Compiles the graph with the configured feed_list / fetch_list.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
// Builds a graph-unique variable name by appending the node's id to its
// current name, e.g. "x" -> "x_42".
std::string GenerateVarName(Node *node) {
  auto unique_name = node->Name();
  unique_name += "_";
  unique_name += std::to_string(node->id());
  return unique_name;
}
// De-duplicates inplaced variables: when several var nodes share one name,
// every intermediate (non-feed, non-fetch, non-persistable) duplicate is
// renamed to a unique name and its producing/consuming op descs updated.
void IpuInplacePass::ApplyImpl(ir::Graph *graph) const {
  // use this pass after forward_graph_extract_pass
  // raise error if the inplaced var both in feed_list & fetch_list
  VLOG(10) << "enter IpuInplacePass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  auto feed_list = Get<std::vector<std::string>>("feed_list");
  auto fetch_list = Get<std::vector<std::string>>("fetch_list");

  // Count var nodes per name; operator[] value-initializes new counts to 0,
  // so the original find/else-increment dance is unnecessary.
  std::map<std::string, int> var_name;
  for (auto *node : graph->Nodes()) {
    if (node->IsVar()) {
      var_name[node->Name()]++;
    }
  }

  for (auto *node : graph->Nodes()) {
    if (node->IsVar()) {
      if (var_name[node->Name()] > 1) {
        // A graph input is a feed-listed node with no producers; a graph
        // output is a fetch-listed node with no consumers. Leave those (and
        // persistable vars) untouched.
        auto is_feed = (std::find(feed_list.begin(), feed_list.end(),
                                  node->Name()) != feed_list.end()) &&
                       node->inputs.empty();
        auto is_fetch = (std::find(fetch_list.begin(), fetch_list.end(),
                                   node->Name()) != fetch_list.end()) &&
                        node->outputs.empty();
        if (!is_feed && !is_fetch && !node->Var()->Persistable()) {
          auto old_name = node->Name();
          auto new_name = GenerateVarName(node);
          node->RenameVar(new_name);
          // Keep producer/consumer op descs consistent with the new name.
          for (auto *op_in : node->inputs) {
            op_in->Op()->RenameOutput(old_name, new_name);
          }
          for (auto *op_out : node->outputs) {
            op_out->Op()->RenameInput(old_name, new_name);
          }
        }
      }
    }
  }

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave IpuInplacePass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Requires the caller to supply both name lists before Apply().
REGISTER_PASS(ipu_inplace_pass, paddle::framework::ir::IpuInplacePass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// IR pass that renames duplicated (inplaced) intermediate variables so each
// var node carries a unique name before IPU compilation.
class IpuInplacePass : public IPUPassBase {
protected:
// Renames non-feed/non-fetch/non-persistable duplicate vars uniquely.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// Caches the pass's printable representation and the graph under
// transformation. The members are `mutable`, which is what allows this
// const method to assign them.
void IPUPassBase::Init(const std::string& repr, Graph* graph) const {
  graph_ = graph;
  repr_ = repr;
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
// Common base class for all IPU-related IR passes. Holds the pass's
// printable representation and the graph being transformed; members are
// `mutable` because Pass::ApplyImpl is const.
class IPUPassBase : public Pass {
public:
// Caches `repr` and `graph` for later use by derived passes.
void Init(const std::string& repr, Graph* graph) const;
virtual ~IPUPassBase() {}
protected:
// Non-owning pointer to the graph currently being processed.
mutable Graph* graph_;
// Human-readable pass identifier (e.g. for logging).
mutable std::string repr_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported.
0% uploaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment.