未验证 提交 d67fe921 编写于 作者: A Allen Guo 提交者: GitHub

[IPU] update ipu_backend (#40685)

* sync changes

* copy sOpNamescope

* fix UTs

* add authors
Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NAllen Guo <alleng@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>

* fix code-format

* fix compile error

* add comments for feed_op
Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>
上级 a6f77fdf
......@@ -25,14 +25,14 @@ std::set<std::string> ignored_ops = {
"sum",
"clip",
"clip_by_norm",
"square",
"reduce_sum",
"sqrt",
"elementwise_max",
"elementwise_div",
"elementwise_mul",
"scale", // adamax
"assign", // adamw
"scale", // adamax
"assign", // adamw
"squared_l2_norm" // gradient_clip_norm
};
const bool startswith(const std::string& str, const std::string& pre) {
......@@ -62,6 +62,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
new_op.SetAttr("with_lr_sched", false);
std::set<std::string> set_ops{};
// save the weight decay tensor_name and weight_decay_value for Lamb
std::vector<std::string> weight_decay_vars{};
std::vector<float> weight_decay_values{};
// use map store <op_type, op_ptr> ?
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) {
......@@ -75,6 +79,15 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
auto op_role = static_cast<OpRole>(op_role_);
if (op_role == OpRole::kOptimize) {
// save weight decay value from every lamb optimizer op
if (op_type == "lamb" && op->HasAttr("weight_decay")) {
auto weight_decay_value =
BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
auto params = op->Output("ParamOut");
weight_decay_vars.push_back(params[0]);
weight_decay_values.push_back(weight_decay_value);
}
if (set_ops.count(op_type)) {
continue;
}
......@@ -270,7 +283,10 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
// seems with_lr_sched is always true
new_op.SetAttr("with_lr_sched", true);
// setup weight deacy
// setup weight decay for Lamb
new_op.SetAttr("weight_decay_vars", weight_decay_vars);
new_op.SetAttr("weight_decay_values", weight_decay_values);
// weight_decay/coeff is "scale" attr of scale_op
if (set_ops.count("scale") && set_ops.count("sum")) {
if (set_ops.count("sign")) {
......
......@@ -30,7 +30,8 @@ void TransferCastOpPass::ApplyImpl(ir::Graph* graph) const {
auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
auto enable_fp16 = ipu_backend->GetIpuStrategy()->enable_fp16;
if (enable_fp16) {
auto transfer_cast_op = ipu_backend->GetIpuStrategy()->transfer_cast_op;
if (enable_fp16 && transfer_cast_op) {
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "popart_cast") {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr("to")) ==
......
......@@ -79,18 +79,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_ipu_place(src_place) &&
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
......@@ -390,6 +378,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
"Copying from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copying from %s to %s is not supported.", src_place, dst_place));
}
#endif
}
template <typename TENSOR>
......@@ -447,27 +458,15 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { /* custom_device -> cpu*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
} // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) { /* cpu -> custom_device*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
} // NOLINT
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_custom_place(
dst_place)) { /* custom_device -> custom_device*/
......@@ -483,11 +482,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
} // NOLINT
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
} // NOLINT
else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_xpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
......@@ -502,7 +501,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto xpu_ctx = platform::DeviceContextPool::Instance().Get(xpu_dst_place);
xpu_ctx->Wait();
}
}
} // NOLINT
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
......@@ -601,6 +600,29 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_IPU
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else if (platform::is_ipu_place(src_place) && // NOLINT
platform::is_ipu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
}
template <typename Predicate, typename DevCtx>
......
......@@ -40,6 +40,13 @@ class FeedVariableVisitor : public boost::static_visitor<void> {
out_var_->GetMutable<framework::LoDTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) {
out_tensor->ShareDataWith(in_tensor);
#ifdef PADDLE_WITH_IPU
} else if (platform::is_ipu_place(place_)) {
// For ipu, both in_tensor and out_tensor are allocated on cpu,
// PopART will copy tensor from host automatically,
// no TensorCopy() is required here.
out_tensor->ShareDataWith(in_tensor);
#endif
} else {
platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_);
......
......@@ -13,7 +13,7 @@ IF(WITH_IPU)
"ipu_device.cc"
)
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper)
cc_library(ipu_backend SRCS ${IPU_BACKEND_SRC} DEPS popart-only graph graph_helper popdist)
cc_library(ipu_info SRCS ${IPU_INFO_SRC} DEPS popart-only enforce)
add_library(paddle_ipu SHARED ${PADDLE_IPU_SRC})
add_dependencies(paddle_ipu ipu_backend)
......
......@@ -32,6 +32,7 @@ IpuBackend* IpuBackend::GetInstance() {
IpuBackend::IpuBackend() {
compiler_ = std::make_unique<Compiler>();
executor_ = std::make_unique<Executor>();
timer_ = std::make_unique<platform::Timer>();
}
IpuBackend::~IpuBackend() {
......@@ -43,6 +44,7 @@ void IpuBackend::Compile(Graph* graph,
const std::vector<std::string>& feed_list,
const std::vector<std::string>& fetch_list) {
VLOG(10) << "enter IpuBackend::Compile";
is_compiled_ = false;
compiler_->Prepare(graph);
compiler_->InitInputs(feed_list);
compiler_->LowerConstants(scope_);
......@@ -52,31 +54,25 @@ void IpuBackend::Compile(Graph* graph,
if (ipu_strategy_->is_training) {
compiler_->LowerOptimizer(scope_);
}
if (!ipu_strategy_->onnx_dump_path.empty()) {
SaveModelProto(ipu_strategy_->onnx_dump_path);
}
executor_->SetCompilerResources(compiler_->GetResources());
executor_->Prepare(compiler_->GetModelProto());
is_compiled_ = true;
// when call compile, means a new graph
is_prepared_ = false;
VLOG(10) << "leave IpuBackend::Compile";
}
void IpuBackend::Run(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs,
const framework::ExecutionContext& ctx) {
Prepare();
timer_->Start();
executor_->Run(inputs, outputs, ctx);
timer_->Pause();
VLOG(10) << "[IPU Run]: " << timer_->ElapsedMS() << " (ms)";
}
void IpuBackend::Prepare() {
if (!is_prepared_) {
executor_->Prepare(compiler_->GetModelProto());
timer_.reset(new platform::Timer());
is_prepared_ = true;
}
}
void IpuBackend::WeightsToHost() { executor_->WeightsToHost(); }
void IpuBackend::Detach() { executor_->Detach(); }
......@@ -101,12 +97,10 @@ void IpuBackend::SetIpuStrategy(const IpuStrategy& strategy) {
}
void IpuBackend::SaveModelProto(const std::string& path) {
if (ipu_strategy_->is_training && is_prepared_) {
if (ipu_strategy_->is_training && is_compiled_) {
executor_->SaveModelToHost(path);
} else if (is_compiled_) {
compiler_->SaveModelProtoNoCheck(path);
} else {
LOG(WARNING) << "Model is empty";
compiler_->SaveModelProtoNoCheck(path);
}
}
......
......@@ -60,6 +60,9 @@ class IpuBackend {
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// Sync weights from IPU while training
void WeightsToHost();
// detach IPU manually
void Detach();
......@@ -76,22 +79,17 @@ class IpuBackend {
void SaveModelProto(const std::string &path);
private:
void Prepare();
private:
std::unique_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_;
bool is_compiled_ = false;
bool is_prepared_ = false;
// not own
const Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr;
private:
// time record for IpuBackend::Run
// own
std::unique_ptr<Compiler> compiler_;
std::unique_ptr<Executor> executor_;
std::unique_ptr<platform::Timer> timer_;
bool is_compiled_ = false;
DISABLE_COPY_AND_ASSIGN(IpuBackend);
};
......
......@@ -18,6 +18,7 @@
#include <popart/adaptive.hpp>
#include <popart/optimizer.hpp>
#include <popart/sgd.hpp>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
......@@ -25,13 +26,20 @@ namespace paddle {
namespace platform {
namespace ipu {
popart::AdamMode AdamModeFromStr(const std::string& str) {
popart::AdamMode AdamModeFromStr(const std::string& str,
const bool& use_no_bias_optimizer) {
if (str == "adam") {
return popart::AdamMode::Adam;
if (!use_no_bias_optimizer)
return popart::AdamMode::Adam;
else
return popart::AdamMode::AdamNoBias;
} else if (str == "adamax") {
return popart::AdamMode::AdaMax;
} else if (str == "lamb") {
return popart::AdamMode::Lamb;
if (!use_no_bias_optimizer)
return popart::AdamMode::Lamb;
else
return popart::AdamMode::LambNoBias;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Uknown AdamMode: %s, AdamMode must be one of these values: adam, "
......@@ -70,6 +78,17 @@ popart::WeightDecayMode WeightDecayModeFromStr(const std::string& str) {
}
}
popart::DataType DataTypeFromStr(const std::string& str) {
if (str == "FLOAT") {
return popart::DataType::FLOAT;
} else if (str == "FLOAT16") {
return popart::DataType::FLOAT16;
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported DataType: %s", str));
}
}
template <typename T>
T GetAttrAllowNull(std::string attr, OpDesc* op_desc) {
if (op_desc->HasAttr(attr)) {
......@@ -122,6 +141,17 @@ void Compiler::Prepare(const Graph* graph) {
builder_ = popart::Builder::create();
resources_ = std::make_unique<CompilerResources>();
graph_helper_ = std::make_unique<GraphHelper>(graph);
// Set the flag of set_amp_for_all_
for (auto* node : graph_helper_->sorted_ops) {
auto* op_desc = node->Op();
auto op_type = op_desc->Type();
if (op_type == "popart_matmul") {
if (op_desc->HasAttr(sAvailMemAttribute)) {
set_amp_for_all_ = false;
return;
}
}
}
}
void Compiler::RegisterOpFunc() {
......@@ -155,7 +185,9 @@ void Compiler::RegisterOpFunc() {
auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
PushNameScope(op_desc); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \
PopNameScope(op_desc); \
SetIpuIndexStage(output_ids, op_desc); \
SetAMPAttributes(output_ids, op_desc); \
SetSerializeAttributes(output_ids, op_desc); \
......@@ -241,7 +273,9 @@ void Compiler::LowerConstants(const Scope* scope) {
popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()),
shape);
const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
PushNameScope(op_desc);
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
PopNameScope(op_desc);
SetIpuIndexStage(result, op_desc);
resources_->tensors.emplace(tensor_name, result);
}
......@@ -261,6 +295,10 @@ void Compiler::LowerWeights(const Scope* scope) {
VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
continue;
}
if (var_name.rfind("learning_rate", 0) == 0) {
VLOG(10) << "skip learning_rate_var: " << var_name;
continue;
}
VLOG(10) << "lowering weight: " << var_name;
auto var = scope->FindVar(var_name);
......@@ -273,10 +311,15 @@ void Compiler::LowerWeights(const Scope* scope) {
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data(), tensor_info};
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(result);
if (!node->outputs.empty()) {
auto op_node = node->outputs[0];
PushNameScope(op_node->Op());
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
PopNameScope(op_node->Op());
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(var_name);
}
}
}
}
......@@ -298,7 +341,10 @@ void Compiler::LowerBody() {
} else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
PushNameScope(op_desc);
auto output_ids = builder_->checkpointOutput(inputs);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc);
......@@ -313,9 +359,11 @@ void Compiler::LowerBody() {
BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
VLOG(10) << "Build graph from custom op: " << __op_type;
auto it = custom_ops_.find(__op_type);
PushNameScope(op_desc);
auto output_ids =
builder_->customOp(it->second.popart_op, it->second.popart_op.version,
inputs, outputs.size(), attributes, debug_context);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_printtensor") {
......@@ -325,8 +373,10 @@ void Compiler::LowerBody() {
auto print_gradient =
BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title"));
PushNameScope(op_desc);
auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
inputs, print_gradient, debug_context, title);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else {
......@@ -367,8 +417,31 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->with_lr_sched = false;
}
VLOG(10) << "Set initial lr: " << resources_->lr;
auto loss_scaling = ipu_strategy_->loss_scaling;
// Get the type of optimizer
auto type = BOOST_GET_CONST(std::string, op_desc->GetAttr("type"));
// Set weight decay by tensor names for Lamb
auto weight_decay_vars = BOOST_GET_CONST(
std::vector<std::string>, op_desc->GetAttr("weight_decay_vars"));
auto weight_decay_values = BOOST_GET_CONST(
std::vector<float>, op_desc->GetAttr("weight_decay_values"));
// Get the maximum permissible value for gradient clipping
std::vector<popart::ClipNormSettings> clip_norm_settings = {};
if (op_desc->HasAttr("clip_norm")) {
auto clip_norm = BOOST_GET_CONST(float, op_desc->GetAttr("clip_norm"));
clip_norm_settings.push_back(
popart::ClipNormSettings::clipAllWeights(clip_norm));
VLOG(10) << "Set the global gradient clipping with the maximum "
"permissible value: "
<< clip_norm;
}
// Values from ipu_strategy
auto loss_scaling = ipu_strategy_->loss_scaling;
auto accl1_type = DataTypeFromStr(ipu_strategy_->accl1_type);
auto accl2_type = DataTypeFromStr(ipu_strategy_->accl2_type);
auto accl3_type = DataTypeFromStr(ipu_strategy_->accl3_type);
if (type == "sgd") {
auto weight_decay =
BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
......@@ -376,12 +449,18 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::SGD>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(momentum, true),
popart::SGD::getUnsetDampening(),
popart::SGD::getUnsetVelocityScaling(),
popart::OptimizerValue(loss_scaling, true));
popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
};
resources_->eval_optimizer = std::make_unique<popart::SGD>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, true), popart::SGD::getUnsetDampening(),
popart::SGD::getUnsetVelocityScaling(),
popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
} else if (type == "adam") {
auto weight_decay =
BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
......@@ -392,22 +471,79 @@ void Compiler::LowerOptimizer(const Scope* scope) {
VLOG(10) << "set max_weight_norm: " << mwn;
auto adam_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("adam_mode"));
auto adam_mode = AdamModeFromStr(adam_mode_);
auto weight_decay_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("weight_decay_mode"));
auto adam_mode =
AdamModeFromStr(adam_mode_, ipu_strategy_->use_no_bias_optimizer);
auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
if (weight_decay_mode_.empty()) {
weight_decay_mode_ = BOOST_GET_CONST(
std::string, op_desc->GetAttr("weight_decay_mode"));
}
auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::Adam>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(beta1, true),
popart::OptimizerValue(beta2, true),
if (adam_mode == popart::AdamMode::Lamb ||
adam_mode == popart::AdamMode::LambNoBias) {
const std::map<std::string, std::pair<float, bool>>
optimizer_value = {{"defaultLearningRate", {lr, false}},
{"defaultBeta1", {beta1, false}},
{"defaultBeta2", {beta2, false}},
{"defaultEps", {eps, true}},
{"lossScaling", {loss_scaling, true}},
{"defaultMaxWeightNorm", {mwn, true}}};
auto optimizer_instance = std::make_unique<popart::Adam>(
optimizer_value, adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, accl1_type, accl2_type,
clip_norm_settings);
for (int i = 0; i < weight_decay_vars.size(); i++) {
optimizer_instance->insertSpecific(
weight_decay_vars[i],
{{"weightDecay", {weight_decay_values[i], false}}});
VLOG(10) << "Set Tensor " << weight_decay_vars[i]
<< " weight decay as " << weight_decay_values[i];
}
return optimizer_instance;
} else {
return std::make_unique<popart::Adam>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(beta1, false),
popart::OptimizerValue(beta2, false),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true),
popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, accl1_type, accl2_type,
clip_norm_settings);
}
};
if (adam_mode == popart::AdamMode::Lamb ||
adam_mode == popart::AdamMode::LambNoBias) {
const std::map<std::string, std::pair<float, bool>> optimizer_value =
{{"defaultLearningRate", {0.0, false}},
{"defaultBeta1", {beta1, false}},
{"defaultBeta2", {beta2, false}},
{"defaultEps", {eps, true}},
{"lossScaling", {loss_scaling, true}},
{"defaultMaxWeightNorm", {mwn, true}}};
auto eval_optimizer = std::make_unique<popart::Adam>(
optimizer_value, adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, popart::DataType::FLOAT,
popart::DataType::FLOAT, clip_norm_settings);
for (int i = 0; i < weight_decay_vars.size(); i++) {
eval_optimizer->insertSpecific(weight_decay_vars[i],
{{"weightDecay", {0.0, false}}});
}
resources_->eval_optimizer = std::move(eval_optimizer);
} else {
resources_->eval_optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(beta1, false),
popart::OptimizerValue(beta2, false),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true),
popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
popart::DataType::UNDEFINED, popart::DataType::FLOAT,
popart::DataType::FLOAT);
};
popart::DataType::FLOAT, clip_norm_settings);
}
} else if (type == "adaptive") {
auto alpha = BOOST_GET_CONST(float, op_desc->GetAttr("alpha"));
auto momentum = BOOST_GET_CONST(float, op_desc->GetAttr("momentum"));
......@@ -417,21 +553,33 @@ void Compiler::LowerOptimizer(const Scope* scope) {
auto adaptive_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("adaptive_mode"));
auto adaptive_mode = AdaptiveModeFromStr(adaptive_mode_);
auto weight_decay_mode_ =
BOOST_GET_CONST(std::string, op_desc->GetAttr("weight_decay_mode"));
auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
if (weight_decay_mode_.empty()) {
weight_decay_mode_ = BOOST_GET_CONST(
std::string, op_desc->GetAttr("weight_decay_mode"));
}
auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
resources_->optimizer_fn = [=](float lr) {
return std::make_unique<popart::Adaptive>(
popart::OptimizerValue(lr, false),
popart::OptimizerValue(weight_decay, true),
popart::OptimizerValue(weight_decay, false),
popart::OptimizerValue(alpha, true),
popart::OptimizerValue(momentum, true),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true), adaptive_mode,
weight_decay_mode, popart::DataType::UNDEFINED,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
weight_decay_mode, popart::DataType::UNDEFINED, accl1_type,
accl2_type, accl3_type);
};
resources_->eval_optimizer = std::make_unique<popart::Adaptive>(
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(0.0, false),
popart::OptimizerValue(alpha, true),
popart::OptimizerValue(momentum, true),
popart::OptimizerValue(eps, true),
popart::OptimizerValue(loss_scaling, true), adaptive_mode,
weight_decay_mode, popart::DataType::UNDEFINED,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::UNDEFINED);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"optimizer %s is not implemented", type));
......@@ -510,9 +658,32 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetAMPAttributes";
if (op_desc->Type() == "popart_matmul") {
auto amp = ipu_strategy_->available_memory_proportion;
if (amp > 0.0f && amp <= 1.0) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
if (set_amp_for_all_) {
auto amp = ipu_strategy_->available_memory_proportion;
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 <= "
"amp <= 1",
amp));
}
if (amp > 0.0f) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
}
} else {
if (op_desc->HasAttr(sAvailMemAttribute)) {
auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute));
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 "
"<= amp <= 1",
amp));
}
if (amp > 0.0f) {
builder_->setAvailableMemoryProportion(tensor_id, amp);
VLOG(10) << "set available_memory_proportion for tensor: "
<< tensor_id << " as " << amp;
}
}
}
}
VLOG(10) << "leave Compiler::SetAMPAttributes";
......@@ -602,6 +773,29 @@ popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) {
return popart::DebugContext(op_identify_id);
}
void Compiler::PushNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
if (!op_namescope.empty()) {
op_namescope.pop_back();
}
if (!op_namescope.empty()) {
op_namescope.erase(op_namescope.begin());
}
VLOG(10) << "name_scope is: " << op_namescope;
builder_->pushNameScope(op_namescope);
}
void Compiler::PopNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
builder_->popNameScope();
}
} // namespace ipu
} // namespace platform
} // namespace paddle
......@@ -50,6 +50,8 @@ struct CompilerResources {
using OptimizerFn =
std::function<std::unique_ptr<popart::Optimizer>(float lr)>;
OptimizerFn optimizer_fn;
// The eval mode of optimizer in training
std::unique_ptr<popart::Optimizer> eval_optimizer;
public:
popart::Optimizer *Optimizer() { return optimizer.get(); }
......@@ -110,6 +112,7 @@ class Compiler {
void RegisterOpFunc();
std::vector<std::string> GetOpInputs(const OpDesc *op);
const std::vector<std::string> &GetOpOutputs(const OpDesc *op);
const std::string GetNameScope(const OpDesc *op);
popart::DebugContext BuildDebugContext(const OpDesc *op);
void InsertTensors(const std::vector<std::string> &output_names,
......@@ -126,6 +129,8 @@ class Compiler {
const OpDesc *op_desc);
void SetSerializeAttributes(const std::string &tensor_id,
const OpDesc *op_desc);
void PushNameScope(const OpDesc *op);
void PopNameScope(const OpDesc *op);
private:
std::unique_ptr<popart::Builder> builder_;
......@@ -137,6 +142,14 @@ class Compiler {
const IpuStrategy *ipu_strategy_ = nullptr;
std::map<std::string, IpuCustomOpIdentifier> custom_ops_;
// Used to choose the way to set amp for Ops
// If anyone op has the attr sAvailMemAttribute, the
// available_memory_proportion from ipu_strategy
// will be ignored and the Ops are set by their own sAvailMemAttribute. Else,
// all relevant Ops will be set by
// the available_memory_proportion from ipu_strategy.
bool set_amp_for_all_ = true;
};
} // namespace ipu
......
......@@ -64,15 +64,10 @@ void Executor::Prepare(const std::string &proto) {
WeightsFromPaddle();
VLOG(10) << "Copy weights from paddle to popart...done";
VLOG(10) << "Copy weights from host to device...";
session_->weightsFromHost();
VLOG(10) << "Copy weights from host to device...done";
if (ipu_strategy_->save_init_onnx) {
session_->modelToHost("test_init.onnx");
if (ipu_strategy_->random_seed != std::numeric_limits<std::uint64_t>::max()) {
VLOG(10) << "Setting random seed to: " << ipu_strategy_->random_seed;
session_->setRandomSeed(ipu_strategy_->random_seed);
}
// init run step
step_ = 0;
}
void Executor::Run(const std::vector<const Tensor *> &inputs,
......@@ -120,11 +115,17 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "Prepared inputs/anchors";
if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched) {
VLOG(10) << "Update learning_rate";
auto new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
VLOG(10) << "New Lr: " << new_lr;
auto *optimizer = compiler_resources_->UpdateOptimizer(new_lr);
popart::Optimizer *optimizer;
if (ipu_strategy_->runtime_options.enable_eval) {
VLOG(10) << "Switch optimizer to eval mode";
optimizer = compiler_resources_->eval_optimizer.get();
} else {
VLOG(10) << "Update learning_rate";
auto new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
VLOG(10) << "New Lr: " << new_lr;
optimizer = compiler_resources_->UpdateOptimizer(new_lr);
}
auto *session = dynamic_cast<popart::TrainingSession *>(session_.get());
session->updateOptimizerFromHost(optimizer);
}
......@@ -133,15 +134,13 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
VLOG(10) << "Running...";
session_->run(stepio);
VLOG(10) << "Running...done";
}
step_++;
if (ipu_strategy_->is_training &&
step_ % ipu_strategy_->save_per_n_step == 0) {
session_->weightsToHost();
void Executor::WeightsToHost() {
if (ipu_strategy_->is_training && session_) {
WeightsToPaddle();
if (ipu_strategy_->save_onnx_checkpoint) {
session_->modelToHost("test_last" + std::to_string(step_) + ".onnx");
}
} else {
LOG(WARNING) << "For a non-trainning graph, cannot sync weights from IPU.";
}
}
......@@ -153,6 +152,7 @@ void Executor::AcquireDevice() {
}
bool use_ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
bool enable_distribution = ipu_strategy_->enable_distribution;
if (use_ipu_model) {
std::map<std::string, std::string> deviceOpts{
{
......@@ -162,6 +162,16 @@ void Executor::AcquireDevice() {
};
device_ = popart::DeviceManager::createDeviceManager().createIpuModelDevice(
deviceOpts);
} else if (enable_distribution) {
auto ipus_per_replica = ipu_strategy_->num_ipus /
ipu_strategy_->popart_options.replicatedGraphCount;
auto device_id = popdist_get_device(ipus_per_replica);
device_ = popart::DeviceManager::createDeviceManager().acquireDeviceById(
device_id);
PADDLE_ENFORCE_NOT_NULL(
device_, platform::errors::Unavailable(
"Can't attach IPU in distribution, ipu_num = %d.",
RequestIpus(ipu_strategy_->num_ipus)));
} else {
device_ =
popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
......@@ -185,28 +195,29 @@ void Executor::SetWeightsIO() {
auto opt_type = compiler_resources_->optimizer_type;
VLOG(10) << "SetWeightsIO for " << opt_type;
auto pre_post_fix = GetOptPrePostfix(opt_type);
for (const auto &weight_id : compiler_resources_->weights) {
for (const auto &weight_pd : compiler_resources_->weights) {
for (const auto &pair : pre_post_fix) {
// pair.first : popart prefix, pair.second : paddle postfix
auto popart_var_name = pair.first + weight_id;
auto paddle_var_name = weight_id + pair.second;
auto weight_pop = compiler_resources_->tensors[weight_pd];
auto popart_var = pair.first + weight_pop;
auto paddle_var = weight_pd + pair.second;
if (scope_->FindVar(paddle_var_name) == nullptr) {
if (scope_->FindVar(paddle_var) == nullptr) {
continue;
}
if (!session_->hasInfo(popart_var_name)) {
if (!session_->hasInfo(popart_var)) {
continue;
}
auto var = scope_->GetVar(paddle_var_name);
VLOG(10) << "Connect paddle weight: " << paddle_var
<< " with popart weight: " << popart_var;
auto var = scope_->GetVar(paddle_var);
auto data_ptr = var->GetMutable<framework::LoDTensor>()->data();
auto tensor_info = session_->getInfo(popart_var_name);
executor_resources_->weights_io.insert(popart_var_name,
auto tensor_info = session_->getInfo(popart_var);
executor_resources_->weights_io.insert(popart_var,
{data_ptr, tensor_info});
executor_resources_->weights_and_opt_state.emplace_back(
std::make_pair(popart_var_name, paddle_var_name));
std::make_pair(popart_var, paddle_var));
}
}
}
......@@ -284,6 +295,7 @@ void Executor::ConvertWeights(bool align_to_popart) {
void Executor::WeightsFromPaddle() {
ConvertWeights(true);
session_->writeWeights(executor_resources_->weights_io);
session_->weightsFromHost();
}
// |-----------------------------------------------------|
......@@ -297,13 +309,13 @@ void Executor::WeightsFromPaddle() {
// Paddle -> halfToFloat: cast then save to paddle
// Popart -> Paddle: copy from paddle to popart
void Executor::WeightsToPaddle() {
session_->weightsToHost();
session_->readWeights(executor_resources_->weights_io);
ConvertWeights(false);
}
void Executor::SaveModelToHost(const std::string &path) {
if (session_) {
session_->weightsToHost();
WeightsToPaddle();
session_->modelToHost(path);
} else {
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <popart/patterns/patterns.hpp>
#include <popart/session.hpp>
#include <popart/tensorinfo.hpp>
#include <popdist/popdist_poplar.hpp>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
......@@ -36,8 +37,7 @@ struct ExecutorResources {
// map<tensor_id, paddle_var_ptr>
popart::WeightsIO weights_io;
// <popart_var, paddle_var> pairs, include weights and optimizer states
std::vector<std::pair<popart::TensorId, popart::TensorId>>
weights_and_opt_state;
std::vector<std::pair<popart::TensorId, std::string>> weights_and_opt_state;
};
class Executor {
......@@ -53,14 +53,12 @@ class Executor {
const std::vector<Tensor *> &outputs,
const framework::ExecutionContext &ctx);
// sync weights from popart to paddle
void WeightsToHost();
// detach IPU
void Detach();
void SetWeightsIO();
void ConvertWeights(bool align_to_popart);
void WeightsFromPaddle();
void WeightsToPaddle();
// Scope
void SetScope(const Scope *scope) { scope_ = scope; }
......@@ -79,6 +77,10 @@ class Executor {
private:
void AcquireDevice();
void SetWeightsIO();
void ConvertWeights(bool);
void WeightsFromPaddle();
void WeightsToPaddle();
private:
// not own
......@@ -92,8 +94,6 @@ class Executor {
std::unique_ptr<popart::Session> session_;
// one OneSession means a graph
std::unique_ptr<ExecutorResources> executor_resources_;
int step_ = 0;
};
} // namespace ipu
......
......@@ -24,6 +24,8 @@ static constexpr const char *sIpuIndexAttr = "ipu_index";
static constexpr const char *sIpuStageAttr = "ipu_stage";
static constexpr const char *sMatmulSerializeFactor = "serialize_factor";
static constexpr const char *sMatmulSerializeMode = "serialize_mode";
static constexpr const char *sAvailMemAttribute = "__available_memory";
static constexpr const char *sOpNamescope = "op_namescope";
static constexpr const char *sOpIdentifyIdAttr = "op_identify_id";
static constexpr const char *sDebugInfoId = "__debug_info_id";
......
......@@ -62,23 +62,40 @@ IpuStrategy::IpuStrategy() {
[&]() { return name; })
ADD_BOOL_OPTION(is_training);
ADD_BOOL_OPTION(save_init_onnx);
ADD_BOOL_OPTION(save_onnx_checkpoint);
ADD_BOOL_OPTION(need_avg_shard);
ADD_BOOL_OPTION(enable_fp16);
ADD_BOOL_OPTION(transfer_cast_op);
ADD_BOOL_OPTION(use_no_bias_optimizer);
ADD_BOOL_OPTION(enable_distribution);
ADD_UINT64_OPTION(num_ipus);
ADD_UINT64_OPTION(batches_per_step);
ADD_UINT64_OPTION(micro_batch_size);
ADD_UINT64_OPTION(save_per_n_step);
ADD_UINT64_OPTION(random_seed);
ADD_DOUBLE_OPTION(available_memory_proportion);
ADD_DOUBLE_OPTION(loss_scaling);
ADD_DOUBLE_OPTION(max_weight_norm);
ADD_STRING_OPTION(accl1_type);
ADD_STRING_OPTION(accl2_type);
ADD_STRING_OPTION(accl3_type);
ADD_STRING_OPTION(onnx_dump_path);
ADD_STRING_OPTION(weight_decay_mode);
#undef ADD_STRING_OPTION
#undef ADD_DOUBLE_OPTION
#undef ADD_UINT64_OPTION
#undef ADD_BOOL_OPTION
#define ADD_RUNTIME_BOOL_OPTION(name, aliased_name) \
RegisterSetter(bool_options, #name, \
[&](bool value) { runtime_options.aliased_name = value; }); \
RegisterGetter(options_getter, options_type, #name, "bool", [&]() { \
return std::to_string(runtime_options.aliased_name); \
})
ADD_RUNTIME_BOOL_OPTION(runtime_options.enable_eval, enable_eval);
#undef ADD_RUNTIME_BOOL_OPTION
#define ADD_POPART_ENUM_OPTION_ALIAS(name, aliased_name, EnumType) \
RegisterSetter(uint64_options, #name, [&](std::uint64_t value) { \
PADDLE_ENFORCE_LT( \
......@@ -171,6 +188,7 @@ IpuStrategy::IpuStrategy() {
ADD_POPART_UINT64_OPTION_ALIAS(merge_var_update_mem_threshold,
mergeVarUpdateMemThreshold);
ADD_POPART_UINT64_OPTION_ALIAS(loose_threshold_at_peak, looseThresholdAtPeak);
ADD_POPART_UINT64_OPTION_ALIAS(replicated_graph_count, replicatedGraphCount);
ADD_POPART_UINT64_OPTION_ALIAS(accumulation_factor, accumulationFactor);
ADD_POPART_UINT64_OPTION_ALIAS(swap_limit_scheduler, swapLimitScheduler);
ADD_POPART_UINT64_OPTION_ALIAS(global_replication_factor,
......@@ -462,12 +480,30 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor,
} else if (opt == "use_io_tiles_to_store") {
settings->location.storageTileSet =
value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
} else if (opt == "sharding_domain_with_all") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::All, value);
} else if (opt == "sharding_domain_with_consecutive") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::Consecutive, value);
} else if (opt == "sharding_domain_with_orthogonal") {
settings->location.shardingDomain =
popart::CommGroup(popart::CommGroupType::Orthogonal, value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown option ' %s' for tensor location: %s", opt, tensor));
}
}
void IpuStrategy::SetAccumulateOuterFragmentSettings(
const std::uint64_t& schedule, const std::vector<int>& values) {
VLOG(10) << "SetAccumulateOuterFragmentSettings schedule:" << schedule;
auto schedule_ =
static_cast<popart::AccumulateOuterFragmentSchedule>(schedule);
popart_options.accumulateOuterFragmentSettings =
popart::AccumulateOuterFragmentSettings(schedule_, values);
}
void IpuStrategy::AddCustomOp(const std::string& paddle_op,
const std::string& popart_op,
const std::string& domain, int version) {
......
......@@ -24,6 +24,11 @@ namespace paddle {
namespace platform {
namespace ipu {
struct RuntimeOptions {
// enable the eval mode in training by switching optimizers.
bool enable_eval = false;
};
class IpuStrategy {
public:
IpuStrategy();
......@@ -32,19 +37,24 @@ class IpuStrategy {
// training flag, true for training
bool is_training = true;
// save the onnx model lowered by paddle program description
bool save_init_onnx = false;
// save the trained model
bool save_onnx_checkpoint = false;
// average sharding, debugging used
bool need_avg_shard = false;
// flag for fp16, true for pure fp16
bool enable_fp16 = false;
// Number ipus total needed, replica * ipu_per_replica
// enable transfer cast Op target from fp32 to fp16 in fp16 mode
bool transfer_cast_op = true;
// The mode of Adam/Lamb optimizer
// false: The standard Adam/Lamb optimizer
// true: The Adam_No_Bias/Lamb_No_Bias optimizer from PopART
bool use_no_bias_optimizer = false;
// enable distributed computing for POD128 or POD256
bool enable_distribution = false;
// Number ipus total needed, local_replica * ipu_per_replica
int num_ipus = 1;
// batches per step
......@@ -53,8 +63,8 @@ class IpuStrategy {
// micro batch-size
int micro_batch_size = 1;
// save paddle model per n steps
int save_per_n_step = 1;
// random seed
std::uint64_t random_seed = std::numeric_limits<std::uint64_t>::max();
// TODO(alleng) remove this param
// available memory proportion, 0.0f for disable
......@@ -67,6 +77,29 @@ class IpuStrategy {
// defaultMaxWeightNorm for adam optimizer
float max_weight_norm = 65504.0f;
// file path for dumping compiled model in onnx format
std::string onnx_dump_path;
// Data type to use for tensor that stores first-order momentum optimizer
// state. FLOAT or FLOAT16
std::string accl1_type = "FLOAT";
// Data type to use for tensor that stores second-order momentum optimizer
// state. FLOAT or FLOAT16
std::string accl2_type = "FLOAT";
// Data type to use for tensor that stores third-order momentum optimizer
// state. FLOAT or FLOAT16
std::string accl3_type = "FLOAT";
// WeightDecayMode for setting the optimizer
// if set, it will override other settings
// value must be one of "decay" or "l2_regularization" or not set
std::string weight_decay_mode = "";
// Runtime Options
RuntimeOptions runtime_options;
// popart session option
popart::SessionOptions popart_options;
......@@ -86,6 +119,8 @@ class IpuStrategy {
const std::string &value);
void SetTensorLocation(const std::string &tensor, const std::string &option,
std::uint64_t value);
void SetAccumulateOuterFragmentSettings(const std::uint64_t &schedule,
const std::vector<int> &values);
void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
const std::string &domain, int version);
......
......@@ -32,30 +32,10 @@ const std::string GenerateOpName() {
const std::string CreateOpIdentifyId(Node *node) {
// format:
// if has custom op_namescope:
// {op_namescope}/op_type/_gen_*
// else:
// {op_type}/{out_var0}/{out_var1}/.../_gen_*
// op_type/_gen_*
// this name will be used as op name when exporting onnx model from popart
auto op_type = node->Name();
std::string op_namescope;
if (node->Op()->HasAttr("op_namescope")) {
op_namescope =
BOOST_GET_CONST(std::string, node->Op()->GetAttr("op_namescope"));
} else {
op_namescope = "/";
}
if (op_namescope != "/") {
return {op_namescope + op_type + "/" + GenerateOpName()};
} else {
std::string op_out = "";
for (auto *out_node : node->outputs) {
op_out += "/";
op_out += out_node->Name();
}
return {op_type + op_out + "/" + GenerateOpName()};
}
return {op_type + "/" + GenerateOpName()};
}
Node *MakeVarNode(Graph *graph, Node *node) {
......@@ -122,6 +102,12 @@ Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type,
if (node->Op()->HasAttr(sMatmulSerializeMode)) {
CopyOpAttr(sMatmulSerializeMode, node->Op(), new_node->Op());
}
if (node->Op()->HasAttr(sAvailMemAttribute)) {
CopyOpAttr(sAvailMemAttribute, node->Op(), new_node->Op());
}
if (node->Op()->HasAttr(sOpNamescope)) {
CopyOpAttr(sOpNamescope, node->Op(), new_node->Op());
}
{
new_node->Op()->SetAttr(sOpIdentifyIdAttr, CreateOpIdentifyId(node));
new_node->Op()->Flush();
......
......@@ -4264,6 +4264,7 @@ All parameter, weight, gradient are variables in Paddle.
platform::ipu::IpuBackend::GetInstance());
},
py::return_value_policy::reference)
.def("weights_to_host", &platform::ipu::IpuBackend::WeightsToHost)
.def("detach", &platform::ipu::IpuBackend::Detach)
.def("reset", &platform::ipu::IpuBackend::Reset)
.def("set_scope", &platform::ipu::IpuBackend::SetScope)
......@@ -4311,6 +4312,15 @@ All parameter, weight, gradient are variables in Paddle.
option_name, option.first.cast<std::string>(),
option.second.cast<std::uint64_t>());
}
} else if (option_name == "accumulate_outer_fragment") {
for (auto option : element.second.cast<py::dict>()) {
std::vector<int> values;
for (auto value : option.second.cast<py::list>()) {
values.push_back(value.cast<int>());
}
self.SetAccumulateOuterFragmentSettings(
option.first.cast<std::uint64_t>(), values);
}
} else if (option_name == "custom_op") {
std::string paddle_op;
std::string popart_op;
......
......@@ -26,7 +26,13 @@ class TestIpuStrategy(unittest.TestCase):
def test_set_options(self):
ipu_strategy = paddle.static.IpuStrategy()
all_option_names = ipu_strategy._ipu_strategy.get_all_option_names()
skip_options = []
skip_options.append('random_seed')
for option_name in all_option_names:
if option_name in skip_options:
continue
option = ipu_strategy._ipu_strategy.get_option(option_name)
option_type = option['type']
option_value = option['value']
......@@ -38,9 +44,13 @@ class TestIpuStrategy(unittest.TestCase):
set_value = not option_value
else:
continue
ipu_strategy.set_options({option_name: set_value})
new_value = ipu_strategy.get_option(option_name)
assert new_value == set_value, f"set {option_name} to {set_value} failed"
try:
ipu_strategy.set_options({option_name: set_value})
new_value = ipu_strategy.get_option(option_name)
assert new_value == set_value, f"set {option_name} to {set_value} failed"
except:
raise Exception(f"set {option_name} to {set_value} failed")
def test_set_string_options(self):
ipu_strategy = paddle.static.IpuStrategy()
......
......@@ -95,12 +95,9 @@ class TestBase(IPUOpTest):
is_training=self.attrs['is_training'])
ipu_strategy.set_precision_config(
enable_fp16=self.attrs['enable_fp16'])
ipu_strategy.set_options({
'save_per_n_step': self.attrs['save_at_step']
})
program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy).compile(
self.feed_list, fetch_list)
ipu_program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy)
program = ipu_program.compile(self.feed_list, fetch_list)
result = []
run_steps = self.attrs['steps'] if save_otherwise_load \
......@@ -111,10 +108,9 @@ class TestBase(IPUOpTest):
for i in range(run_steps):
tmp = exe.run(program, feed=feed, fetch_list=fetch_list)
# currently, we update opt state every sess.run,
# will optimize
if save_otherwise_load and \
i == self.attrs['save_at_step'] - 1:
ipu_program._backend.weights_to_host()
paddle.static.save(main_prog,
self.attrs['model_path'].name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册