Unverified commit 84f257bd, authored by Allen Guo, committed by GitHub

[IPU] update ipu related passes p0 (#38846)

* update ipu related passes
Co-authored-by: Xiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Haicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: Han Zhao <hanzhao@graphcore.ai>

* remove ipu_pass_base

* update error msg

* update error msg 02

* split pr 01

* restore ipu_pass_base
Co-authored-by: Xiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Haicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: Han Zhao <hanzhao@graphcore.ai>
Parent e50d883e
@@ -26,13 +26,15 @@ namespace ir {
 void AvgShardPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(10) << "enter AvgShardPass::ApplyImpl";

-  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
-      platform::ipu::IpuBackend::GetInstance();
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();

   if (ipu_backend->GetIpuStrategy()->need_avg_shard) {
     VLOG(10) << "start AvgShardPass";
     auto nodes = ir::TopologySortOperations(*graph);
     auto num_ipus = ipu_backend->GetIpuStrategy()->num_ipus;
+    auto replica_factor =
+        ipu_backend->GetIpuStrategy()->popart_options.replicatedGraphCount;
+    num_ipus = num_ipus / replica_factor;

     int shard_position = nodes.size() / num_ipus;
     int index_and_stage = -1;
......
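Aside (not part of the diff): the replica adjustment above can be read with concrete numbers. A minimal, self-contained sketch with hypothetical values only:

// Sketch of the sharding arithmetic in AvgShardPass; all numbers are hypothetical.
#include <cstdio>

int main() {
  int num_ipus = 8;        // IpuStrategy::num_ipus as requested by the user
  int replica_factor = 2;  // popart_options.replicatedGraphCount
  num_ipus = num_ipus / replica_factor;     // 4 IPUs visible to one replica
  int num_ops = 100;                        // stand-in for nodes.size()
  int shard_position = num_ops / num_ipus;  // 25: every 25 ops move to the next IPU
  std::printf("shard_position = %d\n", shard_position);
  return 0;
}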
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class AvgShardPass : public IPUPassBase {
+class AvgShardPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -15,13 +15,13 @@
 #pragma once

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class ForwardGraphExtractPass : public IPUPassBase {
+class ForwardGraphExtractPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -29,10 +29,10 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(10) << "Raw Graph: ";
   VLOG(10) << DebugString(graph);

-  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
-      platform::ipu::IpuBackend::GetInstance();
-  auto batch_size = ipu_backend->GetIpuStrategy()->batch_size;
+  // Make batch_size fixed
+  bool need_infer_shape = false;
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
+  auto micro_batch_size = ipu_backend->GetIpuStrategy()->micro_batch_size;
   auto feed_list = Get<std::vector<std::string>>("feed_list");
   for (auto node : graph->Nodes()) {
     if (!node->IsVar()) {
@@ -43,8 +43,9 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
     if (is_feed) {
       auto input_shape = node->Var()->GetShape();
       if (input_shape[0] <= -1) {
-        input_shape[0] = batch_size;
+        input_shape[0] = micro_batch_size;
         node->Var()->SetShape(input_shape);
+        need_infer_shape = true;
       }
       // int64->int32
       if (node->Var()->GetDataType() == proto::VarType::INT64) {
@@ -54,44 +55,63 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
   }

   // temp scope for shape inference
-  std::shared_ptr<paddle::framework::Scope> scope(
-      new paddle::framework::Scope());
-  for (auto node : graph->Nodes()) {
-    if (!node->IsVar()) {
-      continue;
-    }
-    auto var_desc = node->Var();
-    auto* ptr = scope->Var(var_desc->Name());
-    paddle::framework::InitializeVariable(ptr, var_desc->GetType());
-    auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
-    tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
-  }
-
-  // infer shape
-  auto nodes = ir::TopologySortOperations(*graph);
-  for (auto node : nodes) {
-    auto op_desc = node->Op();
-    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-    paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), *scope);
-    op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);
-
-    for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
-      for (int i = 0; i < it->second.size(); i++) {
-        auto output_name = op_desc->Output(it->first)[i];
-        auto dim =
-            it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
-        auto new_shape = paddle::framework::vectorize(dim);
-        for (auto output_node : node->outputs) {
-          if (output_node->Name() == output_name) {
-            output_node->Var()->SetShape(new_shape);
-          }
-        }
-      }
-    }
-  }
-  // release the temp scope
-  scope.reset();
+  if (need_infer_shape) {
+    std::shared_ptr<paddle::framework::Scope> scope(
+        new paddle::framework::Scope());
+    for (auto node : graph->Nodes()) {
+      if (!node->IsVar()) {
+        continue;
+      }
+      auto var_desc = node->Var();
+      auto* ptr = scope->Var(var_desc->Name());
+      paddle::framework::InitializeVariable(ptr, var_desc->GetType());
+      auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
+      tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
+    }
+
+    // infer shape
+    auto nodes = ir::TopologySortOperations(*graph);
+    for (auto node : nodes) {
+      VLOG(10) << "InferShapePass: Infer shape for Op (" << node->Name() << ")";
+      auto op_desc = node->Op();
+      if (op_desc->Type() == "popart_optimizer") {
+        continue;
+      }
+      auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
+      paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(),
+                                            *scope);
+      op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);
+
+      for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
+        for (int i = 0; i < it->second.size(); i++) {
+          auto output_name = op_desc->Output(it->first)[i];
+          auto dim =
+              it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
+          auto new_shape = paddle::framework::vectorize(dim);
+          for (auto output_node : node->outputs) {
+            if (output_node->Name() == output_name) {
+              output_node->Var()->SetShape(new_shape);
+              if (VLOG_IS_ON(10)) {
+                std::ostringstream sout;
+                sout << "InferShapePass: output[" << output_node->Name()
+                     << "], infer shape:[";
+                for (auto s : new_shape) {
+                  sout << std::to_string(s) << ", ";
+                }
+                sout << "]";
+                VLOG(10) << sout.str();
+              }
+            }
+          }
+        }
+      }
+      VLOG(10) << "InferShapePass: Infer shape for Op (" << node->Name()
+               << ") finished";
+    }
+    // release the temp scope
+    scope.reset();
+  }

   VLOG(10) << "Post Graph: ";
   VLOG(10) << DebugString(graph);
......
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class InferShapePass : public IPUPassBase {
+class InferShapePass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class InferencePostprocessPass : public IPUPassBase {
+class InferencePostprocessPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -29,8 +29,7 @@ void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(10) << "enter InferenceProcessPass::ApplyImpl";

   // Get a new instance of ipu_backend
-  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
-      platform::ipu::IpuBackend::GetNewInstance();
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();

   // Set scope
   auto& scope = graph->Get<Scope>(kParamScopeAttr);
@@ -40,18 +39,34 @@
   static std::shared_ptr<platform::ipu::IpuStrategy> ipu_strategy_instance_(
       new platform::ipu::IpuStrategy());
   ipu_strategy_instance_->is_training = false;
+
+  // Set graph replication
+  auto replica_num = graph->Get<int>("replica_num");
+  if (replica_num > 1) {
+    ipu_strategy_instance_->popart_options.enableReplicatedGraphs = true;
+    ipu_strategy_instance_->popart_options.replicatedGraphCount = replica_num;
+  }
+
+  // Set the num of IPUs
   auto num_ipus = graph->Get<int>("num_ipus");
-  ipu_strategy_instance_->num_ipus = num_ipus;
+  // Set sharding
   if (num_ipus > 1) {
-    ipu_strategy_instance_->popart_options_.virtualGraphMode =
+    ipu_strategy_instance_->need_avg_shard = true;
+    ipu_strategy_instance_->popart_options.virtualGraphMode =
         platform::ipu::VirtualGraphMode::Manual;
   } else {
-    ipu_strategy_instance_->popart_options_.virtualGraphMode =
+    ipu_strategy_instance_->need_avg_shard = false;
+    ipu_strategy_instance_->popart_options.virtualGraphMode =
         platform::ipu::VirtualGraphMode::Off;
   }
+  // total num IPUs = num_ipus * replica_num
+  ipu_strategy_instance_->num_ipus = num_ipus * replica_num;
+
+  // Set micro_batch_size for shape inference
+  ipu_strategy_instance_->micro_batch_size =
+      graph->Get<int>("micro_batch_size");
+
+  // Set pipelining
   auto enable_pipelining = graph->Get<bool>("enable_pipelining");
-  ipu_strategy_instance_->popart_options_.enablePipelining = enable_pipelining;
+  ipu_strategy_instance_->popart_options.enablePipelining = enable_pipelining;
   if (enable_pipelining) {
     auto batches_per_step = graph->Get<int>("batches_per_step");
     PADDLE_ENFORCE_GE(
@@ -60,8 +75,20 @@ void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const {
             "greater than the number of IPUs"));
     ipu_strategy_instance_->batches_per_step = batches_per_step;
   }
-  ipu_strategy_instance_->batch_size = graph->Get<int>("batch_size");
-  ipu_strategy_instance_->need_avg_shard = graph->Get<bool>("need_avg_shard");
+
+  // Set FP16
+  auto enable_fp16 = graph->Get<bool>("enable_fp16");
+  ipu_strategy_instance_->enable_fp16 = enable_fp16;
+  if (enable_fp16) {
+    auto enable_half_partial = graph->Get<bool>("enable_half_partial");
+    if (enable_half_partial) {
+      ipu_strategy_instance_->popart_options.partialsTypeMatMuls = "half";
+    }
+  }
+
+  // Set available memory proportion for matmul/conv
+  ipu_strategy_instance_->available_memory_proportion =
+      graph->Get<float>("available_memory_proportion");

   ipu_backend->SetIpuStrategy(*(ipu_strategy_instance_.get()));
@@ -94,9 +121,9 @@ void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const {
   }

   // Run passes
-  std::vector<std::string> graph_pass = {"forward_graph_extract_pass",
-                                         "infer_shape_pass", "avg_shard_pass",
-                                         "popart_canonicalization_pass"};
+  std::vector<std::string> graph_pass = {
+      "forward_graph_extract_pass", "infer_shape_pass", "avg_shard_pass",
+      "popart_canonicalization_pass", "transfer_cast_op_pass"};
   std::vector<std::string> compile_pass = {
       "ipu_inplace_pass", "ipu_graph_builder_pass", "ipu_runtime_replacer_pass",
       "inference_postprocess_pass"};
......
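Aside (not part of the diff): the pass names collected in graph_pass and compile_pass are run through the framework's pass registry. A minimal sketch of that flow, assuming only the standard Pass/PassRegistry API; the attribute wiring the real code does (e.g. "feed_list" for infer_shape_pass) is elided:

// Sketch: apply a list of registered IR passes, by name, to a graph.
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

void RunNamedPasses(paddle::framework::ir::Graph* graph,
                    const std::vector<std::string>& pass_names) {
  for (const auto& name : pass_names) {
    auto pass = paddle::framework::ir::PassRegistry::Instance().Get(name);
    // Required pass attributes (feed_list, fetch_list, ...) must be set on
    // `pass` before Apply() is called.
    pass->Apply(graph);
  }
}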
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class InferenceProcessPass : public IPUPassBase {
+class InferenceProcessPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -32,8 +32,7 @@ void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const {
   std::vector<std::string> fetch_list;
   fetch_list = Get<std::vector<std::string>>("fetch_list");

-  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
-      platform::ipu::IpuBackend::GetInstance();
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();

   ipu_backend->Compile(graph, feed_list, fetch_list);
......
@@ -15,13 +15,13 @@
 #pragma once

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class IpuGraphBuilderPass : public IPUPassBase {
+class IpuGraphBuilderPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class IpuInplacePass : public IPUPassBase {
+class IpuInplacePass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -56,19 +56,6 @@ void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
     }
   }

-  // set ipu_runtime_op dtype attr
-  if (fetch_list.size() == 1) {
-    for (auto* node : graph->Nodes()) {
-      if (node->IsVar()) {
-        for (auto fetch : fetch_list) {
-          if (node->Name() == fetch) {
-            ipu_rt_node->Op()->SetAttr("dtype", node->Var()->GetDataType());
-          }
-        }
-      }
-    }
-  }
-
   // Remove unneeded nodes.
   std::unordered_set<const Node*> marked_nodes;
   for (auto* node : graph->Nodes()) {
......
@@ -15,13 +15,13 @@
 #pragma once

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class IpuRuntimeReplacerPass : public IPUPassBase {
+class IpuRuntimeReplacerPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -15,72 +15,303 @@
 #include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h"

 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
-#include "paddle/fluid/platform/device/ipu/ipu_backend.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
-  VLOG(10) << "enter IpuOptimizerExtractPass::ApplyImpl";
-  VLOG(10) << "Raw Graph: ";
-  VLOG(10) << DebugString(graph);
-
-  auto ipu_backend = paddle::platform::ipu::IpuBackend::GetInstance();
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()) {
-      int op_role = BOOST_GET_CONST(
-          int, node->Op()->GetAttr(
-                   framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
-
-      // graph usually have multiple optimizer node for different parameter,
-      // and these node have the same type and attr value usually
-      if ((op_role == static_cast<int>(framework::OpRole::kOptimize))) {
-        ipu_backend->GetExecutor().SetOptimizerType(node->Op()->Type());
-        VLOG(10) << "found optimizer type: " << node->Op()->Type();
-
-        for (const std::string& attr_name : node->Op()->AttrNames()) {
-          auto attr_type = node->Op()->GetAttrType(attr_name);
-          // with adam, attr are float
-          if (attr_type == proto::AttrType::FLOAT) {
-            auto attr_value =
-                BOOST_GET_CONST(float, node->Op()->GetAttr(attr_name));
-            ipu_backend->GetExecutor().SetOptimizerAttr(attr_name, attr_value);
-          } else {
-            VLOG(10) << "Skip " << attr_type;
-          }
-        }
-
-        auto lr_var_name = node->Op()->Input("LearningRate");
-        PADDLE_ENFORCE_EQ(lr_var_name.size(), 1u,
-                          platform::errors::InvalidArgument(
-                              "In op(%s), find input(LearningRate) failed.",
-                              node->Op()->Type()));
-        ipu_backend->GetExecutor().SetLRVarName(lr_var_name[0]);
-      }
-
-      if ((op_role == static_cast<int>(framework::OpRole::kLoss))) {
-        VLOG(10) << "found loss op type: " << node->Op()->Type();
-        auto outputs = node->Op()->Outputs();
-        PADDLE_ENFORCE_EQ(
-            outputs.size(), 1,
-            platform::errors::InvalidArgument("Can only support one loss key"));
-        auto losses_name = outputs.begin()->second;
-        PADDLE_ENFORCE_EQ(losses_name.size(), 1,
-                          platform::errors::InvalidArgument(
-                              "Can only support one loss name"));
-        ipu_backend->GetExecutor().SetLoss(losses_name[0]);
-      }
-    }
-  }
-
-  VLOG(10) << "Post Graph: ";
-  VLOG(10) << DebugString(graph);
-  VLOG(10) << "leave IpuOptimizerExtractPass::ApplyImpl";
-}
+std::set<std::string> ignored_ops = {
+    "sign",
+    "sum",
+    "clip",
+    "clip_by_norm",
+    "square",
+    "reduce_sum",
+    "sqrt",
+    "elementwise_max",
+    "elementwise_div",
+    "elementwise_mul",
+    "scale",   // adamax
+    "assign",  // adamw
+};
+
+const bool startswith(const std::string& str, const std::string& pre) {
+  if (str.rfind(pre, 0) == 0) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+const bool is_grad_clip_op(const std::string& op_namescope) {
+  return startswith(op_namescope, "/gradient_clip");
+}
+
+const bool is_optimizer_op(const std::string& op_namescope) {
+  return startswith(op_namescope, "/optimizer");
+}
+
+const bool is_regularization_op(const std::string& op_namescope) {
+  return startswith(op_namescope, "/regularization");
+}
+
+void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
+  // The op built here follows the popart definition; some of the values it
+  // needs are obtained later, in LowerOptimizer
+  OpDesc new_op("popart_optimizer", {}, {}, {});
+  new_op.SetAttr("op_role", 0);
+  new_op.SetAttr("with_lr_sched", false);
+
+  std::set<std::string> set_ops{};
+  // use map store <op_type, op_ptr> ?
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp()) {
+      continue;
+    }
+
+    auto op = node->Op();
+    auto op_type = op->Type();
+    int op_role_ = BOOST_GET_CONST(
+        int, op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+    auto op_role = static_cast<OpRole>(op_role_);
+
+    if (op_role == OpRole::kOptimize) {
+      if (set_ops.count(op_type)) {
+        continue;
+      }
+      auto op_namescope =
+          BOOST_GET_CONST(std::string, op->GetAttr("op_namescope"));
+      bool is_grad_clip = is_grad_clip_op(op_namescope);
+      // bool is_optimizer = is_optimizer_op(op_namescope);
+      bool is_regularization = is_regularization_op(op_namescope);
+
+      VLOG(10) << "found optimizer related op: " << op_type;
+      // initial learning_rate will be set in LowerOptimizer
+      set_ops.insert(op_type);
+      if (op_type == "sgd") {
+        auto type = std::string{"sgd"};
+        auto lr_var = op->Input("LearningRate").front();
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("weight_decay", 0.0f);
+        new_op.SetAttr("momentum", 0.0f);
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "momentum") {
+        auto type = std::string{"sgd"};
+        // auto LearningRate = op->Input("LearningRate");
+        auto use_nesterov = BOOST_GET_CONST(bool, op->GetAttr("use_nesterov"));
+        PADDLE_ENFORCE_EQ(use_nesterov, false,
+                          platform::errors::Unimplemented(
+                              "ipu does not support nesterov mode."));
+        auto regularization_method =
+            BOOST_GET_CONST(std::string, op->GetAttr("regularization_method"));
+        PADDLE_ENFORCE_NE(regularization_method, "l1_decay",
+                          platform::errors::Unimplemented(
+                              "ipu does not support l1_decay mode."));
+        auto multi_precision =
+            BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
+        PADDLE_ENFORCE_EQ(multi_precision, false,
+                          platform::errors::Unimplemented(
+                              "ipu does not support multi_precision mode."));
+        auto rescale_grad = BOOST_GET_CONST(float, op->GetAttr("rescale_grad"));
+        PADDLE_ENFORCE_EQ(rescale_grad, 1.0,
+                          platform::errors::Unimplemented(
+                              "ipu does not support rescale_grad mode."));
+        auto regularization_coeff =
+            BOOST_GET_CONST(float, op->GetAttr("regularization_coeff"));
+        auto lr_var = op->Input("LearningRate").front();
+        auto momentum = BOOST_GET_CONST(float, op->GetAttr("mu"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("momentum", momentum);
+        new_op.SetAttr("weight_decay", regularization_coeff);
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "adam" || op_type == "adamw") {
+        auto type = std::string{"adam"};
+        auto lr_var = op->Input("LearningRate").front();
+        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
+        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        auto lazy_mode = BOOST_GET_CONST(bool, op->GetAttr("lazy_mode"));
+        auto multi_precision =
+            BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
+        PADDLE_ENFORCE_EQ(lazy_mode, false,
+                          platform::errors::Unimplemented(
+                              "ipu does not support lazy_mode mode."));
+        PADDLE_ENFORCE_EQ(multi_precision, false,
+                          platform::errors::Unimplemented(
+                              "ipu does not support multi_precision mode."));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("weight_decay", 0.0f);
+        new_op.SetAttr("beta1", beta1);
+        new_op.SetAttr("beta2", beta2);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("adam_mode", std::string{"adam"});
+        // adam or adamw
+        if (op_type == "adam") {
+          new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+          new_op.SetAttr("raw_type", std::string{"adam"});
+        } else {
+          new_op.SetAttr("weight_decay_mode", std::string{"decay"});
+          new_op.SetAttr("raw_type", std::string{"adamw"});
+        }
+      } else if (op_type == "adamax") {
+        auto type = std::string{"adam"};
+        auto lr_var = op->Input("LearningRate").front();
+        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
+        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("weight_decay", 0.0f);
+        new_op.SetAttr("beta1", beta1);
+        new_op.SetAttr("beta2", beta2);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("adam_mode", std::string{"adamax"});
+        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "lamb") {
+        // use decay mode
+        auto type = std::string{"adam"};
+        auto lr_var = op->Input("LearningRate").front();
+        auto weight_decay = BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
+        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
+        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("weight_decay", weight_decay);
+        new_op.SetAttr("beta1", beta1);
+        new_op.SetAttr("beta2", beta2);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("adam_mode", std::string{"lamb"});
+        new_op.SetAttr("weight_decay_mode", std::string{"decay"});
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "adadelta") {
+        // NO LearningRate
+        auto type = std::string{"adaptive"};
+        auto rho = BOOST_GET_CONST(float, op->GetAttr("rho"));
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("weight_decay", 0.0f);
+        new_op.SetAttr("alpha", rho);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("momentum", 0.0f);
+        new_op.SetAttr("adaptive_mode", std::string{"adadelta"});
+        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "adagrad") {
+        auto type = std::string{"adaptive"};
+        auto lr_var = op->Input("LearningRate").front();
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("lr_var", lr_var);
+        new_op.SetAttr("weight_decay", 0.0f);
+        // `alpha` use default
+        new_op.SetAttr("alpha", 0.99f);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("momentum", 0.0f);
+        new_op.SetAttr("adaptive_mode", std::string{"adagrad"});
+        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+        new_op.SetAttr("raw_type", op_type);
+      } else if (op_type == "rmsprop") {
+        auto type = std::string{"adaptive"};
+        auto lr_var = op->Input("LearningRate").front();
+        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+        auto decay = BOOST_GET_CONST(float, op->GetAttr("decay"));
+        auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum"));
+        auto centered = BOOST_GET_CONST(bool, op->GetAttr("centered"));
+        new_op.SetAttr("type", type);
+        new_op.SetAttr("weight_decay", 0.0f);
+        new_op.SetAttr("alpha", decay);
+        new_op.SetAttr("eps", epsilon);
+        new_op.SetAttr("momentum", momentum);
+        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+        if (centered) {
+          new_op.SetAttr("adaptive_mode", std::string{"centered_rmsprop"});
+          new_op.SetAttr("raw_type", op_type);
+        } else {
+          new_op.SetAttr("adaptive_mode", std::string{"rmsprop"});
+          new_op.SetAttr("raw_type", op_type);
+        }
+      } else if (is_regularization && op_type == "scale") {
+        // set weight_decay for L2Decay
+        auto scale = BOOST_GET_CONST(float, op->GetAttr("scale"));
+        new_op.SetAttr("weight_decay", scale);
+      } else if (is_grad_clip && op_type == "fill_constant") {
+        // set clip_norm for ClipGradByGlobalNorm
+        auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
+        new_op.SetAttr("clip_norm", value);
+      } else if (ignored_ops.count(op_type)) {
+        VLOG(10) << "Ignore optimizer related op: " << op_type;
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unknown optimizer related op_type: %s", op_type));
+      }
+    } else if (op_role == OpRole::kLoss) {
+      VLOG(10) << "found loss op type: " << op->Type();
+      auto outputs = op->Outputs();
+      PADDLE_ENFORCE_EQ(
+          outputs.size(), 1,
+          platform::errors::InvalidArgument("Can only support one loss key"));
+      auto losses = outputs.begin()->second;
+      PADDLE_ENFORCE_EQ(
+          losses.size(), 1,
+          platform::errors::InvalidArgument("Can only support one loss name"));
+      auto loss_var = losses.front();
+      new_op.SetAttr("loss_var", loss_var);
+    } else if (op_role == OpRole::kLRSched) {
+      // op_role == OpRole::kLRSched | OpRole::kOptimize
+      new_op.SetAttr("with_lr_sched", true);
+    }
+  }
+
+  // seems with_lr_sched is always true
+  new_op.SetAttr("with_lr_sched", true);
+
+  // setup weight decay
+  // weight_decay/coeff is "scale" attr of scale_op
+  if (set_ops.count("scale") && set_ops.count("sum")) {
+    if (set_ops.count("sign")) {
+      // L1Decay
+      // sign + scale + sum
+      PADDLE_THROW(
+          platform::errors::Unimplemented("Unsupported L1Decay regularizer"));
+    } else {
+      // L2Decay
+      // scale + sum
+      new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
+    }
+  } else {
+    VLOG(10) << "No weight decay setting found";
+  }
+
+  // setup grad clip
+  if (set_ops.count("clip")) {
+    // ClipGradByValue
+    PADDLE_THROW(
+        platform::errors::Unimplemented("Unsupported ClipGradByValue"));
+  } else if (set_ops.count("clip_by_norm")) {
+    // ClipGradByNorm
+    PADDLE_THROW(platform::errors::Unimplemented("Unsupported ClipGradByNorm"));
+  }
+
+  // ClipGradByGlobalNorm
+  // use graph pattern match ClipGradByGlobalNorm
+  // square + reduce_sum + sum + sqrt + fill_constant
+  // + elementwise_max + elementwise_div + elementwise_mul
+  // clip_norm from fill_constant`s attr `value` dtype float
+
+  if (new_op.HasAttr("type")) {
+    auto new_node = graph->CreateOpNode(&new_op);
+    VLOG(10) << "New Optimizer Node:";
+    VLOG(10) << DebugString(new_node);
+  } else {
+    PADDLE_THROW(platform::errors::NotFound(
+        "No optimizer found, optimizer must be one of these types: sgd, "
+        "momentum, adam, adamw, adamax, lamb, adadelta, adagrad or rmsprop"));
+  }
+}

 }  // namespace ir
......
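Aside (not part of the diff): as an illustration, for a program trained with plain SGD the pass above leaves behind a single popart_optimizer op whose attributes look roughly like this. The lr_var name is hypothetical; the real values are copied from the matched sgd op.

// Sketch only: the popart_optimizer OpDesc produced for an SGD program.
#include "paddle/fluid/framework/op_desc.h"

paddle::framework::OpDesc MakeSgdOptimizerDesc() {
  paddle::framework::OpDesc opt("popart_optimizer", {}, {}, {});
  opt.SetAttr("op_role", 0);
  opt.SetAttr("with_lr_sched", true);
  opt.SetAttr("type", std::string{"sgd"});
  opt.SetAttr("lr_var", std::string{"learning_rate_0"});  // hypothetical var name
  opt.SetAttr("weight_decay", 0.0f);
  opt.SetAttr("momentum", 0.0f);
  opt.SetAttr("raw_type", std::string{"sgd"});
  return opt;
}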
@@ -14,14 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class IpuOptimizerExtractPass : public IPUPassBase {
+class IpuOptimizerExtractPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -14,23 +14,19 @@
 #include "paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h"

 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
-#include "paddle/fluid/platform/device/ipu/common.h"
 #include "paddle/fluid/platform/device/ipu/ipu_backend.h"
+#include "paddle/fluid/platform/device/ipu/ipu_names.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-using paddle::platform::ipu::IpuBackend;
-using framework::ir::Graph;
-using framework::ir::Node;
-
 void IpuOptimizerStateAlignPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(10) << "enter IpuOptimizerStateAlignPass::ApplyImpl";
   VLOG(10) << "Raw Graph: ";
   VLOG(10) << DebugString(graph);

-  auto ipu_backend = IpuBackend::GetInstance();
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
   const auto* scope_ = ipu_backend->GetScope();

   for (auto* node : graph->Nodes()) {
@@ -15,7 +15,7 @@
 #pragma once

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
@@ -26,7 +26,7 @@ namespace ir {
  * include Adam/Lamb.
  */

-class IpuOptimizerStateAlignPass : public IPUPassBase {
+class IpuOptimizerStateAlignPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......
@@ -21,15 +21,13 @@ namespace paddle {
 namespace framework {
 namespace ir {

-using framework::ir::Graph;
-using framework::ir::Node;
-using platform::ipu::SymbolHandler;
-
 void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(10) << "enter PopartCanonicalizationPass::ApplyImpl";
   VLOG(10) << "Raw Graph: ";
   VLOG(10) << DebugString(graph);

+  auto custom_ops = Get<std::unordered_set<std::string>>("custom_ops");
+  std::vector<std::string> missing_ops;
   auto nodes = graph->Nodes();
   for (auto* node : nodes) {
     if (!node->IsOp()) {
@@ -39,21 +37,40 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
     auto op_type = op->Type();
     ir::Node* new_node = nullptr;
-    SymbolHandler handler = platform::ipu::GetHandler(op_type);
+    platform::ipu::SymbolHandler handler = platform::ipu::GetHandler(op_type);
+    if (!handler && !custom_ops.empty()) {
+      if (custom_ops.count(op_type)) {
+        VLOG(10) << "Found custom op: " << op_type;
+        handler = platform::ipu::GetHandler("custom_op");
+      }
+    }

     if (handler) {
       VLOG(11) << "Raw Paddle Node:";
       VLOG(11) << node->Op()->Proto()->DebugString();
       new_node = handler(graph, node);
-      VLOG(11) << "Post Popart Node:";
-      VLOG(11) << new_node->Op()->Proto()->DebugString();
-
-      platform::ipu::ClearNode(node);
-      graph->RemoveNode(node);
+      if (new_node) {
+        VLOG(11) << "Post Popart Node:";
+        VLOG(11) << new_node->Op()->Proto()->DebugString();
+
+        platform::ipu::ClearNode(node);
+        graph->RemoveNode(node);
+      }
     } else {
-      LOG(ERROR) << "Can not find OpHandler for op_type: " << op_type;
+      missing_ops.push_back(op_type);
     }
   }

+  if (!missing_ops.empty()) {
+    LOG(ERROR) << "Can not find OpHandler for op_type: ";
+    for (auto& op_type : missing_ops) {
+      LOG(ERROR) << op_type;
+    }
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Found unimplemented op_handler(s) for IPU"));
+  }
+
+  // post popart_canonicalization
   VLOG(10) << "Post Graph: ";
   VLOG(10) << DebugString(graph);
   VLOG(10) << "leave PopartCanonicalizationPass::ApplyImpl";
@@ -64,4 +81,5 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace paddle

 REGISTER_PASS(popart_canonicalization_pass,
-              paddle::framework::ir::PopartCanonicalizationPass);
+              paddle::framework::ir::PopartCanonicalizationPass)
+    .DefaultPassAttr("custom_ops", new std::unordered_set<std::string>{});
@@ -14,13 +14,13 @@
 #pragma once

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class PopartCanonicalizationPass : public IPUPassBase {
+class PopartCanonicalizationPass : public Pass {
  protected:
  void ApplyImpl(ir::Graph* graph) const override;
 };
......