Unverified commit 6ec89eeb authored by Allen Guo, committed by GitHub

[IPU] merge recent changes (#42078)

* merge recent changes

* fix setting pipeline
Parent 20286ae7
...@@ -185,12 +185,9 @@ void Compiler::RegisterOpFunc() { ...@@ -185,12 +185,9 @@ void Compiler::RegisterOpFunc() {
auto debug_context = BuildDebugContext(op_desc); \ auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \ auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \ auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
PushNameScope(op_desc); \ NameScopeHelper ns_helper(op_desc, builder_.get()); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \ auto output_ids = OnnxImpl(inputs Args, debug_context); \
PopNameScope(op_desc); \ PostLower(output_ids, op_desc); \
SetIpuIndexStage(output_ids, op_desc); \
SetAMPAttributes(output_ids, op_desc); \
SetSerializeAttributes(output_ids, op_desc); \
InsertTensors(output_names, output_ids); \ InsertTensors(output_names, output_ids); \
}}, // NOLINT }}, // NOLINT
#include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h" #include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h"
...@@ -273,10 +270,9 @@ void Compiler::LowerConstants(const Scope* scope) { ...@@ -273,10 +270,9 @@ void Compiler::LowerConstants(const Scope* scope) {
popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()), popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()),
shape); shape);
const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info)); const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
PushNameScope(op_desc); NameScopeHelper ns_helper(op_desc, builder_.get());
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data); popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
PopNameScope(op_desc); PostLower(result, op_desc);
SetIpuIndexStage(result, op_desc);
resources_->tensors.emplace(tensor_name, result); resources_->tensors.emplace(tensor_name, result);
} }
} }
...@@ -285,42 +281,42 @@ void Compiler::LowerConstants(const Scope* scope) { ...@@ -285,42 +281,42 @@ void Compiler::LowerConstants(const Scope* scope) {
void Compiler::LowerWeights(const Scope* scope) { void Compiler::LowerWeights(const Scope* scope) {
VLOG(10) << "enter Compiler::LowerWeights"; VLOG(10) << "enter Compiler::LowerWeights";
// at this step, the graph doesn't contains optimizer related states // At this step, the graph doesn't contains optimizer related states
for (auto id : graph_helper_->sorted_vars_id) { for (auto id : graph_helper_->sorted_vars_id) {
auto* node = graph_helper_->nodes_id_map[id]; auto* node = graph_helper_->nodes_id_map[id];
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { // Weights are var node and Persistable
if (node->Var()->Persistable() && node->inputs.empty()) { if (node->IsVar() && !node->IsCtrlVar() && node->Var() &&
auto var_name = node->Var()->Name(); node->Var()->Persistable()) {
if (resources_->tensors.count(var_name) != 0) { // Weights are Parameter in training mode
VLOG(10) << "found existed one, skip lowering Weight: " << var_name; if (ipu_strategy_->is_training && !node->Var()->IsParameter()) {
continue; continue;
} }
if (var_name.rfind("learning_rate", 0) == 0) { auto var_name = node->Var()->Name();
VLOG(10) << "skip learning_rate_var: " << var_name; // Some op has same input and output tensor, like batchnorm
continue; if (resources_->tensors.count(var_name) != 0) {
} VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
VLOG(10) << "lowering weight: " << var_name; continue;
}
auto var = scope->FindVar(var_name); VLOG(10) << "lowering weight: " << var_name;
if (var) { auto var = scope->FindVar(var_name);
auto tensor = var->Get<framework::LoDTensor>(); PADDLE_ENFORCE_NOT_NULL(
auto dtype = PdDataType2PopartType(tensor.dtype()); var, platform::errors::NotFound("Tensor %s is not found in the scope",
auto shape = std::vector<int64_t>(); var_name));
for (size_t i = 0; i < tensor.dims().size(); ++i) { auto tensor = var->Get<framework::LoDTensor>();
shape.push_back(tensor.dims().at(i)); auto dtype = PdDataType2PopartType(tensor.dtype());
} auto shape = std::vector<int64_t>();
popart::TensorInfo tensor_info(dtype, shape); for (size_t i = 0; i < tensor.dims().size(); ++i) {
popart::ConstVoidData const_data{tensor.data(), tensor_info}; shape.push_back(tensor.dims().at(i));
if (!node->outputs.empty()) { }
auto op_node = node->outputs[0]; popart::TensorInfo tensor_info(dtype, shape);
PushNameScope(op_node->Op()); popart::ConstVoidData const_data{tensor.data(), tensor_info};
popart::TensorId result = if (!node->outputs.empty()) {
builder_->addInitializedInputTensor(const_data, var_name); auto op_node = node->outputs[0];
PopNameScope(op_node->Op()); NameScopeHelper ns_helper(op_node->Op(), builder_.get());
resources_->tensors.emplace(var_name, result); popart::TensorId result =
resources_->weights.push_back(var_name); builder_->addInitializedInputTensor(const_data, var_name);
} resources_->tensors.emplace(var_name, result);
} resources_->weights.push_back(var_name);
} }
} }
} }
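The reworked LowerWeights drops the explicit learning_rate special case and instead filters candidates by persistability and, in training mode, by whether the var is a Parameter. Below is a compact sketch of that selection rule with plain booleans standing in for the graph-node queries; the function name and parameters are illustrative only, not part of the patch.

```cpp
#include <cassert>

// Sketch of the weight-selection rule from the hunk above: only persistable
// var nodes are considered, and in training mode the var must also be a
// Parameter, which makes the old explicit learning_rate check unnecessary.
bool ShouldLowerAsWeight(bool is_var, bool is_ctrl_var, bool has_var_desc,
                         bool persistable, bool is_parameter,
                         bool is_training) {
  if (!is_var || is_ctrl_var || !has_var_desc || !persistable) return false;
  if (is_training && !is_parameter) return false;
  return true;
}

int main() {
  // Under this rule a persistable non-Parameter var is skipped only in
  // training mode.
  assert(!ShouldLowerAsWeight(true, false, true, true, false, /*is_training=*/true));
  assert(ShouldLowerAsWeight(true, false, true, true, false, /*is_training=*/false));
}
```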
...@@ -341,10 +337,9 @@ void Compiler::LowerBody() { ...@@ -341,10 +337,9 @@ void Compiler::LowerBody() {
} else if (op_type == "popart_checkpointoutput") { } else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc); auto outputs = GetOpOutputs(op_desc);
PushNameScope(op_desc); NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = builder_->checkpointOutput(inputs); auto output_ids = builder_->checkpointOutput(inputs);
PopNameScope(op_desc); PostLower(output_ids, op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids); InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") { } else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
...@@ -359,12 +354,11 @@ void Compiler::LowerBody() { ...@@ -359,12 +354,11 @@ void Compiler::LowerBody() {
BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type")); BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
VLOG(10) << "Build graph from custom op: " << __op_type; VLOG(10) << "Build graph from custom op: " << __op_type;
auto it = custom_ops_.find(__op_type); auto it = custom_ops_.find(__op_type);
PushNameScope(op_desc); NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = auto output_ids =
builder_->customOp(it->second.popart_op, it->second.popart_op.version, builder_->customOp(it->second.popart_op, it->second.popart_op.version,
inputs, outputs.size(), attributes, debug_context); inputs, outputs.size(), attributes, debug_context);
PopNameScope(op_desc); PostLower(output_ids, op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids); InsertTensors(outputs, output_ids);
} else if (op_type == "popart_printtensor") { } else if (op_type == "popart_printtensor") {
auto inputs = GetOpInputs(op_desc); auto inputs = GetOpInputs(op_desc);
...@@ -373,11 +367,10 @@ void Compiler::LowerBody() { ...@@ -373,11 +367,10 @@ void Compiler::LowerBody() {
auto print_gradient = auto print_gradient =
BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient")); BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title")); auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title"));
PushNameScope(op_desc); NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = builder_->aiGraphcoreOpset1().printtensor( auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
inputs, print_gradient, debug_context, title); inputs, print_gradient, debug_context, title);
PopNameScope(op_desc); PostLower(output_ids, op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids); InsertTensors(outputs, output_ids);
} else { } else {
auto itr = name_function_.find(op_type); auto itr = name_function_.find(op_type);
...@@ -625,12 +618,13 @@ void Compiler::InsertTensors(const std::vector<std::string>& output_names, ...@@ -625,12 +618,13 @@ void Compiler::InsertTensors(const std::vector<std::string>& output_names,
resources_->tensors.emplace(output_names[0], tensor_id); resources_->tensors.emplace(output_names[0], tensor_id);
} }
void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids, void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) { const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetIpuIndexStage"; // Set pipline
// Due to the limitation of popart, if an op has multiple outputs,
// pipline settings needs to be set at the same time
auto tensor_ids_set = auto tensor_ids_set =
std::set<std::string>(tensor_ids.begin(), tensor_ids.end()); std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
if (op_desc->HasAttr(sIpuIndexAttr)) { if (op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr)); auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_ids_set, ipu_index); builder_->virtualGraph(tensor_ids_set, ipu_index);
...@@ -639,18 +633,24 @@ void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids, ...@@ -639,18 +633,24 @@ void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids,
if (op_desc->HasAttr(sIpuStageAttr)) { if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr)); auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_ids_set, ipu_stage); builder_->pipelineStage(tensor_ids_set, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
<< " for op: " << op_desc->Type(); << " for op: " << op_desc->Type();
} }
} }
VLOG(10) << "leave Compiler::SetIpuIndexStage";
for (auto& tensor_id : tensor_ids) {
PostLower(tensor_id, op_desc, true);
}
} }
void Compiler::SetIpuIndexStage(const std::string& tensor_id, void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) {
const OpDesc* op_desc) { PostLower(tensor_id, op_desc, false);
VLOG(10) << "enter Compiler::SetIpuIndexStage"; }
if (op_desc->HasAttr(sIpuIndexAttr)) { void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc,
bool skip_pipline) {
// Set pipline
if (!skip_pipline && op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr)); auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_id, ipu_index); builder_->virtualGraph(tensor_id, ipu_index);
VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
...@@ -658,32 +658,18 @@ void Compiler::SetIpuIndexStage(const std::string& tensor_id, ...@@ -658,32 +658,18 @@ void Compiler::SetIpuIndexStage(const std::string& tensor_id,
if (op_desc->HasAttr(sIpuStageAttr)) { if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr)); auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_id, ipu_stage); builder_->pipelineStage(tensor_id, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
<< " for op: " << op_desc->Type(); << " for op: " << op_desc->Type();
} }
} }
VLOG(10) << "leave Compiler::SetIpuIndexStage"; // Set amp
}
void Compiler::SetAMPAttributes(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) {
if (op_desc->Type() == "popart_matmul") {
for (const auto& tensor_id : tensor_ids) {
SetAMPAttributes(tensor_id, op_desc);
}
}
}
void Compiler::SetAMPAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetAMPAttributes";
if (op_desc->Type() == "popart_matmul") { if (op_desc->Type() == "popart_matmul") {
if (set_amp_for_all_) { if (set_amp_for_all_) {
auto amp = ipu_strategy_->available_memory_proportion; auto amp = ipu_strategy_->available_memory_proportion;
if (amp < 0.0f || amp > 1.0) { if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 <= " "AvailableMemoryProportion %f is invalid, which should be in "
"amp <= 1", "range [0.0, 1.0]",
amp)); amp));
} }
if (amp > 0.0f) { if (amp > 0.0f) {
...@@ -694,8 +680,8 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id, ...@@ -694,8 +680,8 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute)); auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute));
if (amp < 0.0f || amp > 1.0) { if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 " "AvailableMemoryProportion %f is invalid, which should be in "
"<= amp <= 1", "range [0.0, 1.0]",
amp)); amp));
} }
if (amp > 0.0f) { if (amp > 0.0f) {
...@@ -705,17 +691,7 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id, ...@@ -705,17 +691,7 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
} }
} }
} }
} // Set serialize matmul
VLOG(10) << "leave Compiler::SetAMPAttributes";
}
void Compiler::SetSerializeAttributes(
const std::vector<std::string>& tensor_ids, const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetSerializeAttributes";
auto tensor_ids_set =
std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
if (op_desc->Type() == "popart_matmul") {
if (op_desc->HasAttr(sMatmulSerializeFactor)) { if (op_desc->HasAttr(sMatmulSerializeFactor)) {
auto factor = auto factor =
BOOST_GET_CONST(int, op_desc->GetAttr(sMatmulSerializeFactor)); BOOST_GET_CONST(int, op_desc->GetAttr(sMatmulSerializeFactor));
...@@ -724,16 +700,9 @@ void Compiler::SetSerializeAttributes( ...@@ -724,16 +700,9 @@ void Compiler::SetSerializeAttributes(
mode = BOOST_GET_CONST(std::string, mode = BOOST_GET_CONST(std::string,
op_desc->GetAttr(sMatmulSerializeMode)); op_desc->GetAttr(sMatmulSerializeMode));
} }
builder_->setSerializeMatMul(tensor_ids_set, mode, (int64_t)factor, true); builder_->setSerializeMatMul({tensor_id}, mode, factor, true);
} }
} }
VLOG(10) << "leave Compiler::SetSerializeAttributes";
}
void Compiler::SetSerializeAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
std::vector<std::string> tensor_ids = {tensor_id};
SetSerializeAttributes(tensor_ids, op_desc);
} }
void Compiler::SetCustomOps( void Compiler::SetCustomOps(
...@@ -793,29 +762,6 @@ popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) { ...@@ -793,29 +762,6 @@ popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) {
return popart::DebugContext(op_identify_id); return popart::DebugContext(op_identify_id);
} }
void Compiler::PushNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
if (!op_namescope.empty()) {
op_namescope.pop_back();
}
if (!op_namescope.empty()) {
op_namescope.erase(op_namescope.begin());
}
VLOG(10) << "name_scope is: " << op_namescope;
builder_->pushNameScope(op_namescope);
}
void Compiler::PopNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
builder_->popNameScope();
}
} // namespace ipu } // namespace ipu
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -70,7 +70,7 @@ struct CompilerResources { ...@@ -70,7 +70,7 @@ struct CompilerResources {
std::unique_ptr<popart::Optimizer> optimizer; std::unique_ptr<popart::Optimizer> optimizer;
}; };
// helper for lowering graph // Helper for lowering graph
struct GraphHelper { struct GraphHelper {
explicit GraphHelper(const Graph *); explicit GraphHelper(const Graph *);
...@@ -81,6 +81,30 @@ struct GraphHelper { ...@@ -81,6 +81,30 @@ struct GraphHelper {
std::vector<int> sorted_vars_id; std::vector<int> sorted_vars_id;
}; };
// Helper for adding namescope info
struct NameScopeHelper {
NameScopeHelper(const OpDesc *op, popart::Builder *builder)
: builder_(builder) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope.empty() || op_namescope == "/") {
return;
}
op_namescope.pop_back();
op_namescope.erase(op_namescope.begin());
builder->pushNameScope(op_namescope);
pushed_ = true;
}
~NameScopeHelper() {
if (pushed_) {
builder_->popNameScope();
}
}
bool pushed_ = false;
popart::Builder *builder_;
};
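One subtlety worth noting in the constructor above: it strips the trailing and leading '/' from the Paddle name-scope string before pushing it to popart, and skips empty or root ("/") scopes entirely. The standalone sketch below mirrors that trimming, assuming the scope string has the form "/scope1/scope2/" as the pop_back/erase pair implies.

```cpp
#include <cassert>
#include <string>

// Sketch of the trimming done in NameScopeHelper's constructor. It assumes the
// input is either empty, "/", or of the form "/.../" with a leading and a
// trailing slash.
std::string TrimNameScope(std::string op_namescope) {
  if (op_namescope.empty() || op_namescope == "/") {
    return "";  // nothing to push
  }
  op_namescope.pop_back();                   // drop trailing '/'
  op_namescope.erase(op_namescope.begin());  // drop leading '/'
  return op_namescope;
}

int main() {
  assert(TrimNameScope("/block0/conv/") == "block0/conv");
  assert(TrimNameScope("/") == "");
  assert(TrimNameScope("") == "");
}
```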
class Compiler { class Compiler {
public: public:
Compiler(); Compiler();
...@@ -119,18 +143,9 @@ class Compiler { ...@@ -119,18 +143,9 @@ class Compiler {
const std::vector<std::string> &tensor_ids); const std::vector<std::string> &tensor_ids);
void InsertTensors(const std::vector<std::string> &output_names, void InsertTensors(const std::vector<std::string> &output_names,
const std::string &tensor_id); const std::string &tensor_id);
void SetIpuIndexStage(const std::vector<std::string> &tensor_ids, void PostLower(const std::vector<std::string> &, const OpDesc *);
const OpDesc *op_desc); void PostLower(const std::string &, const OpDesc *);
void SetIpuIndexStage(const std::string &tensor_id, const OpDesc *op_desc); void PostLower(const std::string &, const OpDesc *, bool);
void SetAMPAttributes(const std::vector<std::string> &tensor_ids,
const OpDesc *op_desc);
void SetAMPAttributes(const std::string &tensor_id, const OpDesc *op_desc);
void SetSerializeAttributes(const std::vector<std::string> &tensor_ids,
const OpDesc *op_desc);
void SetSerializeAttributes(const std::string &tensor_id,
const OpDesc *op_desc);
void PushNameScope(const OpDesc *op);
void PopNameScope(const OpDesc *op);
private: private:
std::unique_ptr<popart::Builder> builder_; std::unique_ptr<popart::Builder> builder_;
......
...@@ -20,6 +20,40 @@ namespace paddle { ...@@ -20,6 +20,40 @@ namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
// Get the popart prefix and paddle postfix of weight states
// Format: {popart_prefix, paddle_postfix}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string &opt_type) {
std::vector<std::pair<std::string, std::string>> pre_post_fix;
// Weight self
pre_post_fix.push_back(std::make_pair("", ""));
// Weight states
// TODO(alleng) support pair("Accl1___", "_moment1_{id!=0}")
if (opt_type == "adam" || opt_type == "lamb" || opt_type == "adamw") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "momentum") {
pre_post_fix.push_back(std::make_pair("Accl___", "_velocity_0"));
} else if (opt_type == "adamax") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_inf_norm__0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "adagrad") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment_0"));
} else if (opt_type == "adadelta") {
pre_post_fix.push_back(std::make_pair("Accl1___", "__avg_squared_grad_0"));
pre_post_fix.push_back(
std::make_pair("Accl2___", "__avg_squared_update_0"));
} else if (opt_type == "rmsprop") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_mean_square_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_mean_grad_0"));
pre_post_fix.push_back(std::make_pair("Accl3___", "_momentum__0"));
}
return pre_post_fix;
}
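To make the mapping concrete: for each optimizer state, the popart tensor name is obtained by prepending the first element of a pair to the weight name, and the corresponding Paddle variable name by appending the second element. The example below illustrates this for the adam pairs returned above; the weight name "fc_0.w_0" is hypothetical, and the prepend/append convention is an assumption based on how the pairs are consumed elsewhere in the executor.

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  const std::string weight = "fc_0.w_0";  // hypothetical weight name
  // Pairs as returned by GetOptPrePostfix("adam").
  std::vector<std::pair<std::string, std::string>> pre_post_fix = {
      {"", ""},                        // the weight itself
      {"Accl1___", "_moment1_0"},      // first moment
      {"Accl2___", "_moment2_0"},      // second moment
      {"Step___", "_beta1_pow_acc_0"}  // step counter
  };
  for (const auto& p : pre_post_fix) {
    std::string popart_name = p.first + weight;   // e.g. Accl1___fc_0.w_0
    std::string paddle_name = weight + p.second;  // e.g. fc_0.w_0_moment1_0
    std::cout << popart_name << " <-> " << paddle_name << "\n";
  }
}
```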
Executor::~Executor() { Executor::~Executor() {
Detach(); Detach();
session_.reset(); session_.reset();
......
...@@ -412,6 +412,15 @@ IpuStrategy::IpuStrategy() { ...@@ -412,6 +412,15 @@ IpuStrategy::IpuStrategy() {
RegisterGetter(map_options_getter, options_type, "gcl_options", "map", RegisterGetter(map_options_getter, options_type, "gcl_options", "map",
[&]() { return popart_options.gclOptions; }); [&]() { return popart_options.gclOptions; });
// Default options
// Can also be set as a custom logger in python, like using tqdm
popart_options.compilationProgressLogger = [](int progress, int total) {
if (progress % 10 == 0) {
VLOG(1) << "compile progress: " << progress << "%";
}
};
} }
void IpuStrategy::AddBoolOption(const std::string& option, bool value) { void IpuStrategy::AddBoolOption(const std::string& option, bool value) {
...@@ -513,6 +522,11 @@ void IpuStrategy::AddCustomOp(const std::string& paddle_op, ...@@ -513,6 +522,11 @@ void IpuStrategy::AddCustomOp(const std::string& paddle_op,
IpuCustomOpIdentifier(paddle_op, popart_op, domain, version)); IpuCustomOpIdentifier(paddle_op, popart_op, domain, version));
} }
void IpuStrategy::SetCompilationProgressLogger(
const std::function<void(int, int)>& logger) {
popart_options.compilationProgressLogger = logger;
}
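The logger contract is simply a callable taking two ints: the current progress value and the total, reported by popart during graph compilation. The default installed in the constructor above logs every 10%; SetCompilationProgressLogger lets the pybind layer swap in a user-supplied callback (for example a tqdm updater from Python). The self-contained sketch below only demonstrates the callback shape; the StrategyLike struct is a stand-in, not the real IpuStrategy.

```cpp
#include <functional>
#include <iostream>

// Minimal stand-in showing the logger contract: an (int progress, int total)
// callback that the strategy stores and popart invokes while compiling.
struct StrategyLike {
  std::function<void(int, int)> compilationProgressLogger;
};

int main() {
  StrategyLike strategy;
  strategy.compilationProgressLogger = [](int progress, int total) {
    if (progress % 10 == 0)
      std::cout << "compile progress: " << progress << "/" << total << "\n";
  };
  // Simulate popart reporting progress during compilation.
  for (int p = 0; p <= 100; p += 5) strategy.compilationProgressLogger(p, 100);
}
```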
std::string IpuStrategy::GetOption(const std::string& option) { std::string IpuStrategy::GetOption(const std::string& option) {
return get(option, options_getter); return get(option, options_getter);
} }
......
...@@ -125,6 +125,8 @@ class IpuStrategy { ...@@ -125,6 +125,8 @@ class IpuStrategy {
const std::vector<int> &values); const std::vector<int> &values);
void AddCustomOp(const std::string &paddle_op, const std::string &popart_op, void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
const std::string &domain, int version); const std::string &domain, int version);
void SetCompilationProgressLogger(
const std::function<void(int, int)> &logger);
std::string GetOption(const std::string &); std::string GetOption(const std::string &);
std::vector<std::string> GetVectorOption(const std::string &); std::vector<std::string> GetVectorOption(const std::string &);
......
...@@ -184,27 +184,6 @@ bool GetBoolEnv(std::string str) { ...@@ -184,27 +184,6 @@ bool GetBoolEnv(std::string str) {
} }
} }
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string& opt_type) {
// format: {popart_tensor_id, paddle_tensor_id}, ...
std::vector<std::pair<std::string, std::string>> pre_post_fix;
if (opt_type == "adam" || opt_type == "lamb") {
pre_post_fix.push_back(std::make_pair("", ""));
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "sgd" || opt_type == "momentum") {
// sgd
pre_post_fix.push_back(std::make_pair("", ""));
} else {
pre_post_fix.push_back(std::make_pair("", ""));
//
}
return pre_post_fix;
}
int RequestIpus(const int num_ipus) { int RequestIpus(const int num_ipus) {
// num_ipus must be pow(2, n); // num_ipus must be pow(2, n);
return std::pow(2, ceil(log2(num_ipus))); return std::pow(2, ceil(log2(num_ipus)));
......
...@@ -229,9 +229,6 @@ struct ConstantOpAttrVisitor : public boost::static_visitor<void> { ...@@ -229,9 +229,6 @@ struct ConstantOpAttrVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { RaiseError(); } void operator()(boost::blank) const { RaiseError(); }
}; };
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string& opt_type);
int RequestIpus(const int num_ipus); int RequestIpus(const int num_ipus);
} // namespace ipu } // namespace ipu
......
...@@ -4357,7 +4357,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -4357,7 +4357,10 @@ All parameter, weight, gradient are variables in Paddle.
for (auto element : opt) { for (auto element : opt) {
auto option_name = element.first.cast<std::string>(); auto option_name = element.first.cast<std::string>();
VLOG(10) << "Set option: " << option_name; VLOG(10) << "Set option: " << option_name;
if (py::isinstance<py::bool_>(element.second)) { if (option_name == "compilation_progress_logger") {
self.SetCompilationProgressLogger(
element.second.cast<py::function>());
} else if (py::isinstance<py::bool_>(element.second)) {
self.AddBoolOption(option_name, element.second.cast<bool>()); self.AddBoolOption(option_name, element.second.cast<bool>());
} else if (py::isinstance<py::float_>(element.second)) { } else if (py::isinstance<py::float_>(element.second)) {
self.AddDoubleOption(option_name, self.AddDoubleOption(option_name,
......
...@@ -11,4 +11,5 @@ if(WITH_IPU) ...@@ -11,4 +11,5 @@ if(WITH_IPU)
set_tests_properties(test_conv_op_ipu PROPERTIES TIMEOUT 300) set_tests_properties(test_conv_op_ipu PROPERTIES TIMEOUT 300)
set_tests_properties(test_elemetwise_x_op_ipu PROPERTIES TIMEOUT 300) set_tests_properties(test_elemetwise_x_op_ipu PROPERTIES TIMEOUT 300)
set_tests_properties(test_reduce_x_op_ipu PROPERTIES TIMEOUT 600) set_tests_properties(test_reduce_x_op_ipu PROPERTIES TIMEOUT 600)
set_tests_properties(test_save_load_ipu PROPERTIES TIMEOUT 600)
endif() endif()
...@@ -73,10 +73,15 @@ class TestIpuStrategy(unittest.TestCase): ...@@ -73,10 +73,15 @@ class TestIpuStrategy(unittest.TestCase):
'autoReport.directory': 'path', 'autoReport.directory': 'path',
'autoReport.all': 'true' 'autoReport.all': 'true'
} }
options['random_seed'] = 1234
for k, v in options.items(): for k, v in options.items():
ipu_strategy.set_options({k: v}) ipu_strategy.set_options({k: v})
assert v == ipu_strategy.get_option(k), f"set {k} to {v} failed " assert v == ipu_strategy.get_option(k), f"set {k} to {v} failed "
# The custom logger need 2 int as inputs
logger = lambda progress, total: print(f"compile progrss: {progress}/{total}")
ipu_strategy.set_options({'compilation_progress_logger': logger})
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
import tempfile import tempfile
import unittest import unittest
from functools import partial
import numpy as np import numpy as np
import paddle import paddle
import paddle.optimizer
import paddle.static import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
...@@ -28,7 +30,8 @@ class TestBase(IPUOpTest): ...@@ -28,7 +30,8 @@ class TestBase(IPUOpTest):
self.set_atol() self.set_atol()
self.set_data_feed() self.set_data_feed()
self.set_feed_attr() self.set_feed_attr()
self.set_op_attrs() self.set_attrs()
self.set_optimizer()
def set_data_feed(self): def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10]) data = np.random.uniform(size=[1, 3, 10, 10])
...@@ -39,15 +42,16 @@ class TestBase(IPUOpTest): ...@@ -39,15 +42,16 @@ class TestBase(IPUOpTest):
self.feed_shape = [x.shape for x in self.feed_fp32.values()] self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys()) self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self): def set_attrs(self):
self.attrs = {} self.attrs = {}
self.attrs['steps'] = 100 self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20 self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'sgd'
self.attrs['enable_fp16'] = False self.attrs['enable_fp16'] = False
self.attrs['model_path'] = tempfile.TemporaryDirectory() self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.SGD, learning_rate=1e-1)
def _test_base(self, save_otherwise_load): def _test_base(self, save_otherwise_load):
scope = paddle.static.Scope() scope = paddle.static.Scope()
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
...@@ -71,16 +75,8 @@ class TestBase(IPUOpTest): ...@@ -71,16 +75,8 @@ class TestBase(IPUOpTest):
name='conv2d') name='conv2d')
loss = paddle.mean(conv1) loss = paddle.mean(conv1)
if self.attrs['is_training']: # apply optimizer
if self.attrs['opt_type'] == 'sgd': self.optimizer().minimize(loss)
sgd = paddle.optimizer.SGD(learning_rate=1e-2)
sgd.minimize(loss)
elif self.attrs['opt_type'] == 'adam':
adam = paddle.optimizer.Adam(learning_rate=1e-2)
adam.minimize(loss)
elif self.attrs['opt_type'] == 'lamb':
lamb = paddle.optimizer.Lamb(learning_rate=1e-2)
lamb.minimize(loss)
fetch_list = [loss.name] fetch_list = [loss.name]
place = paddle.IPUPlace() place = paddle.IPUPlace()
...@@ -91,8 +87,7 @@ class TestBase(IPUOpTest): ...@@ -91,8 +87,7 @@ class TestBase(IPUOpTest):
paddle.static.load(main_prog, self.attrs['model_path'].name) paddle.static.load(main_prog, self.attrs['model_path'].name)
ipu_strategy = paddle.static.IpuStrategy() ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config( ipu_strategy.set_graph_config(is_training=True)
is_training=self.attrs['is_training'])
ipu_strategy.set_precision_config( ipu_strategy.set_precision_config(
enable_fp16=self.attrs['enable_fp16']) enable_fp16=self.attrs['enable_fp16'])
ipu_program = paddle.static.IpuCompiledProgram( ipu_program = paddle.static.IpuCompiledProgram(
...@@ -131,62 +126,109 @@ class TestBase(IPUOpTest): ...@@ -131,62 +126,109 @@ class TestBase(IPUOpTest):
self.attrs['model_path'].cleanup() self.attrs['model_path'].cleanup()
class TestMomentum(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Momentum, learning_rate=1e-1)
class TestAdam(TestBase): class TestAdam(TestBase):
def set_op_attrs(self): def set_optimizer(self):
self.attrs = {} self.optimizer = partial(paddle.optimizer.Adam, learning_rate=1e-1)
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'adam'
self.attrs['enable_fp16'] = False
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestLamb(TestBase): class TestLamb(TestBase):
def set_op_attrs(self): def set_optimizer(self):
self.attrs = {} self.optimizer = partial(paddle.optimizer.Lamb, learning_rate=1e-1)
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True class TestAdamW(TestBase):
self.attrs['opt_type'] = 'lamb' def set_optimizer(self):
self.attrs['enable_fp16'] = False self.optimizer = partial(paddle.optimizer.AdamW, learning_rate=1e-1)
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestAdamax(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adamax, learning_rate=1e-1)
class TestAdagrad(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestAdadelta(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestRMSProp(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.RMSProp, learning_rate=1e-1)
class TestCenteredRMSProp(TestBase):
def set_optimizer(self):
self.optimizer = partial(
paddle.optimizer.RMSProp, learning_rate=1e-1, centered=True)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel") @unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestSGDFP16(TestBase): class TestSGDFP16(TestBase):
def set_op_attrs(self): def set_attrs(self):
self.attrs = {} self.attrs = {}
self.attrs['steps'] = 100 self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20 self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'sgd'
self.attrs['enable_fp16'] = True self.attrs['enable_fp16'] = True
self.attrs['model_path'] = tempfile.TemporaryDirectory() self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.SGD, learning_rate=1e-1)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestAdamFP16(TestBase):
def set_op_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'adam'
self.attrs['enable_fp16'] = True
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestMomentumFp16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Momentum, learning_rate=1e-1)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestLambFP16(TestBase): class TestAdamFP16(TestSGDFP16):
def set_op_attrs(self): def set_optimizer(self):
self.attrs = {} self.optimizer = partial(paddle.optimizer.Adam, learning_rate=1e-1)
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True class TestLambFP16(TestSGDFP16):
self.attrs['opt_type'] = 'lamb' def set_optimizer(self):
self.attrs['enable_fp16'] = True self.optimizer = partial(paddle.optimizer.Lamb, learning_rate=1e-1)
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestAdamWFP16FP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.AdamW, learning_rate=1e-1)
class TestAdamaxFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adamax, learning_rate=1e-1)
class TestAdagradFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestAdadeltaFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestRMSPropFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.RMSProp, learning_rate=1e-1)
class TestCenteredRMSPropFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(
paddle.optimizer.RMSProp, learning_rate=1e-1, centered=True)
if __name__ == "__main__": if __name__ == "__main__":
......
# A image for building paddle binaries # A image for building paddle binaries
# build docker image # build docker image
# docker build -t paddlepaddle/paddle:ipu-dev-2.3.0 -f tools/dockerfile/Dockerfile.ipu . # docker build -t paddlepaddle/paddle:latest-dev-ipu -f tools/dockerfile/Dockerfile.ipu .
# run a container # run a container
# docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK --device=/dev/infiniband/ --ipc=host --rm -it paddlepaddle/paddle:ipu-dev-2.3.0 bash # docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK --device=/dev/infiniband/ --ipc=host --rm -it paddlepaddle/paddle:latest-dev-ipu bash
FROM graphcore/poplar:2.3.0 FROM graphcore/poplar:2.3.0
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
......