未验证 提交 05c98ec7 编写于 作者: A Allen Guo 提交者: GitHub

update ipu_executor, remove ipu_optimizer (#38986)

Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NAllen Guo <alleng@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NHaicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>
Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NHaicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>
上级 b2aee3e3
...@@ -12,52 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,52 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/ipu/ipu_executor.h" #include "paddle/fluid/platform/device/ipu/ipu_executor.h"
using float16 = paddle::platform::float16;
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
Executor::Executor() {} Executor::~Executor() {
Detach();
session_.reset();
executor_resources_.reset();
}
void Executor::Prepare(const std::string &proto) {
VLOG(10) << "enter Executor::Prepare";
Executor::~Executor() {} AcquireDevice();
executor_resources_ = std::make_unique<ExecutorResources>();
void Executor::Prepare(const std::string &proto,
const std::map<std::string, popart::TensorId> &tensors,
const std::vector<popart::TensorId> &outputs,
std::shared_ptr<popart::DeviceInfo> device) {
auto art = popart::AnchorReturnType("All"); auto art = popart::AnchorReturnType("All");
std::map<popart::TensorId, popart::AnchorReturnType> anchor_ids; std::map<popart::TensorId, popart::AnchorReturnType> anchor_ids;
for (const auto &id : outputs) { for (const auto &id : compiler_resources_->outputs) {
anchor_ids.emplace(id, art); anchor_ids.emplace(id, art);
} }
auto dataFlow = popart::DataFlow(ipu_strategy_->batches_per_step, anchor_ids); auto dataFlow = popart::DataFlow(ipu_strategy_->batches_per_step, anchor_ids);
PADDLE_ENFORCE_NOT_NULL(device, platform::errors::Unavailable( if (ipu_strategy_->is_training) {
"IPU device isn't attached, please call "
"IpuBackend::AttachDevice(id) first."));
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) {
VLOG(10) << "Creating TrainingSession from Onnx Model..."; VLOG(10) << "Creating TrainingSession from Onnx Model...";
auto popart_optimizer = GetPopartOptimizer(opt_info); auto optimizer = compiler_resources_->NewOptimizer();
auto it = tensors.find(opt_info.GetLoss());
PADDLE_ENFORCE_NE(
it, tensors.end(),
paddle::platform::errors::InvalidArgument(
"loss_id = %s doesn't exist in popart graph.", opt_info.GetLoss()));
session_ = popart::TrainingSession::createFromOnnxModel( session_ = popart::TrainingSession::createFromOnnxModel(
proto, dataFlow, it->second, *popart_optimizer, device, proto, dataFlow, compiler_resources_->loss_var, *optimizer, device_,
popart::InputShapeInfo(), ipu_strategy_->popart_options_, popart::InputShapeInfo(), ipu_strategy_->popart_options,
popart::Patterns(popart::PatternsLevel::Default)); ipu_strategy_->popart_patterns);
} else { } else {
VLOG(10) << "Creating InferenceSession from Onnx Model..."; VLOG(10) << "Creating InferenceSession from Onnx Model...";
session_ = popart::InferenceSession::createFromOnnxModel( session_ = popart::InferenceSession::createFromOnnxModel(
proto, dataFlow, device, popart::InputShapeInfo(), proto, dataFlow, device_, popart::InputShapeInfo(),
ipu_strategy_->popart_options_, ipu_strategy_->popart_options, ipu_strategy_->popart_patterns);
popart::Patterns(popart::PatternsLevel::Default));
} }
VLOG(10) << "Creating session from Onnx Model...done"; VLOG(10) << "Creating session from Onnx Model...done";
...@@ -78,30 +71,27 @@ void Executor::Prepare(const std::string &proto, ...@@ -78,30 +71,27 @@ void Executor::Prepare(const std::string &proto,
if (ipu_strategy_->save_init_onnx) { if (ipu_strategy_->save_init_onnx) {
session_->modelToHost("test_init.onnx"); session_->modelToHost("test_init.onnx");
} }
// init run step
step_ = 0;
} }
void Executor::Run(const std::vector<popart::TensorId> &inputs_id, void Executor::Run(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const std::vector<popart::TensorId> &outputs_id,
const std::vector<framework::Tensor *> &outputs,
const framework::ExecutionContext &ctx) { const framework::ExecutionContext &ctx) {
VLOG(10) << "enter Executor::Run";
// inputs // inputs
std::map<popart::TensorId, popart::IArray &> popart_inputs; std::map<popart::TensorId, popart::IArray &> popart_inputs;
std::map<popart::TensorId, PaddleIArray> input_wrappers; std::map<popart::TensorId, PaddleIArray> input_wrappers;
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
auto tensor_id = inputs_id[i]; auto tensor_id = compiler_resources_->inputs[i];
framework::Tensor *tensor = nullptr; input_wrappers.emplace(tensor_id, PaddleIArray(inputs[i]));
tensor->ShareDataWith(*inputs[i]);
input_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id)); popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
} }
// anchors // anchors
std::map<popart::TensorId, popart::IArray &> popart_anchors; std::map<popart::TensorId, popart::IArray &> popart_anchors;
std::map<popart::TensorId, PaddleIArray> anchor_wrappers; std::map<popart::TensorId, PaddleIArray> anchor_wrappers;
for (size_t i = 0; i < outputs.size(); i++) { for (size_t i = 0; i < outputs.size(); i++) {
auto tensor_id = outputs_id[i]; auto tensor_id = compiler_resources_->outputs[i];
framework::Tensor *tensor = nullptr;
tensor->ShareDataWith(*outputs[i]);
// get dims & dtype from session // get dims & dtype from session
auto fetch_info = session_->getInfo(tensor_id); auto fetch_info = session_->getInfo(tensor_id);
auto output_shape = fetch_info.shape(); auto output_shape = fetch_info.shape();
...@@ -109,6 +99,16 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id, ...@@ -109,6 +99,16 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
output_shape.insert(output_shape.begin(), output_shape.insert(output_shape.begin(),
ipu_strategy_->batches_per_step); ipu_strategy_->batches_per_step);
} }
if (ipu_strategy_->popart_options.enableGradientAccumulation) {
output_shape.insert(output_shape.begin(),
ipu_strategy_->popart_options.accumulationFactor);
}
if (ipu_strategy_->popart_options.enableReplicatedGraphs) {
output_shape.insert(output_shape.begin(),
ipu_strategy_->popart_options.replicatedGraphCount);
}
auto *tensor = outputs[i];
tensor->Resize(framework::make_ddim(output_shape)); tensor->Resize(framework::make_ddim(output_shape));
auto fetch_dtype = fetch_info.dataType(); auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype); auto paddle_type = PopartType2VarType(fetch_dtype);
...@@ -116,13 +116,16 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id, ...@@ -116,13 +116,16 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor)); anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id)); popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
} }
VLOG(10) << "Prepared inputs/anchors";
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) {
VLOG(10) << "Update optimizer learning rate..."; if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched) {
SetLR(GetLRFromScope()); VLOG(10) << "Update learning_rate";
auto popart_optimizer = GetPopartOptimizer(opt_info); auto new_lr =
auto &session = dynamic_cast<popart::TrainingSession &>(*session_); GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
session.updateOptimizerFromHost(popart_optimizer.get()); VLOG(10) << "New Lr: " << new_lr;
auto *optimizer = compiler_resources_->UpdateOptimizer(new_lr);
auto *session = dynamic_cast<popart::TrainingSession *>(session_.get());
session->updateOptimizerFromHost(optimizer);
} }
popart::StepIO stepio(popart_inputs, popart_anchors); popart::StepIO stepio(popart_inputs, popart_anchors);
...@@ -130,44 +133,54 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id, ...@@ -130,44 +133,54 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
session_->run(stepio); session_->run(stepio);
VLOG(10) << "Running...done"; VLOG(10) << "Running...done";
if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) { step_++;
if (ipu_strategy_->is_training &&
step_ % ipu_strategy_->save_per_n_step == 0) {
session_->weightsToHost(); session_->weightsToHost();
WeightsToPaddle(); WeightsToPaddle();
if (ipu_strategy_->save_last_onnx) { if (ipu_strategy_->save_onnx_checkpoint) {
session_->modelToHost("test_last.onnx"); session_->modelToHost("test_last" + std::to_string(step_) + ".onnx");
} }
} }
} }
void Executor::SetOptimizerType(const std::string &type) { void Executor::AcquireDevice() {
opt_info.SetType(type); VLOG(10) << "enter Executor::AcquireDevice";
} if (device_) {
Detach();
void Executor::SetLR(float lr_rate) { opt_info.SetLR(lr_rate); } device_.reset();
}
void Executor::SetOptimizerAttr(const std::string &attr, float value) {
opt_info.SetAttr(attr, value);
}
void Executor::SetLoss(const std::string &loss) { opt_info.SetLoss(loss); }
void Executor::SetLRVarName(const std::string &name) { bool use_ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
opt_info.SetLRVarName(name); if (use_ipu_model) {
std::map<std::string, std::string> deviceOpts{{"numIPUs", "1 "}};
device_ = popart::DeviceManager::createDeviceManager().createIpuModelDevice(
deviceOpts);
} else {
device_ =
popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
RequestIpus(ipu_strategy_->num_ipus));
PADDLE_ENFORCE_NOT_NULL(device_, platform::errors::Unavailable(
"Can't attach IPU, ipu_num = %d.",
RequestIpus(ipu_strategy_->num_ipus)));
}
VLOG(10) << "leave Executor::AcquireDevice";
} }
void Executor::SetWeights(const std::vector<popart::TensorId> &weights) { void Executor::Detach() {
weights_ = weights; if (device_ && device_->isAttached()) {
VLOG(10) << "trying to detach IPU";
device_->detach();
VLOG(10) << " detached IPU";
}
} }
void Executor::SetWeightsIO() { void Executor::SetWeightsIO() {
auto opt_type = opt_info.GetType(); auto opt_type = compiler_resources_->optimizer_type;
VLOG(10) << "SetWeightsIO for " << opt_type;
auto pre_post_fix = GetOptPrePostfix(opt_type); auto pre_post_fix = GetOptPrePostfix(opt_type);
for (const auto &weight_id : weights_) { for (const auto &weight_id : compiler_resources_->weights) {
for (const auto &pair : pre_post_fix) { for (const auto &pair : pre_post_fix) {
if (!IsOptimizerSupported(opt_type)) {
continue;
}
// pair.first : popart prefix, pair.second : paddle postfix // pair.first : popart prefix, pair.second : paddle postfix
auto popart_var_name = pair.first + weight_id; auto popart_var_name = pair.first + weight_id;
auto paddle_var_name = weight_id + pair.second; auto paddle_var_name = weight_id + pair.second;
...@@ -176,32 +189,120 @@ void Executor::SetWeightsIO() { ...@@ -176,32 +189,120 @@ void Executor::SetWeightsIO() {
continue; continue;
} }
if (!session_->hasInfo(popart_var_name)) {
continue;
}
auto var = scope_->GetVar(paddle_var_name); auto var = scope_->GetVar(paddle_var_name);
auto data_ptr = var->GetMutable<framework::LoDTensor>()->data<float>(); auto data_ptr = var->GetMutable<framework::LoDTensor>()->data();
auto tensor_info = session_->getInfo(popart_var_name); auto tensor_info = session_->getInfo(popart_var_name);
weights_io_.insert(popart_var_name, {data_ptr, tensor_info}); executor_resources_->weights_io.insert(popart_var_name,
{data_ptr, tensor_info});
executor_resources_->weights_and_opt_state.emplace_back(
std::make_pair(popart_var_name, paddle_var_name));
} }
} }
} }
void Executor::WeightsFromPaddle() { session_->writeWeights(weights_io_); } // align_to_popart: align dtype to popart if true, else to paddle
void Executor::ConvertWeights(bool align_to_popart) {
for (auto weight_pair : executor_resources_->weights_and_opt_state) {
auto paddle_var = scope_->GetVar(weight_pair.second);
auto paddle_var_dtype = VarType2PopartType(
paddle_var->GetMutable<framework::LoDTensor>()->type());
void Executor::WeightsToPaddle() { session_->readWeights(weights_io_); } PADDLE_ENFORCE_EQ((paddle_var_dtype == popart::DataType::FLOAT ||
paddle_var_dtype == popart::DataType::FLOAT16),
true,
platform::errors::InvalidArgument(
"Currently, we only support FLOAT16 and FLOAT with "
"Paddle, but received type is %s.",
paddle_var_dtype));
popart::TensorInfo info = session_->getInfo(weight_pair.first);
auto popart_var_dtype = info.dataType();
PADDLE_ENFORCE_EQ((popart_var_dtype == popart::DataType::FLOAT ||
popart_var_dtype == popart::DataType::FLOAT16),
true,
platform::errors::InvalidArgument(
"Currently, we only support FLOAT16 and FLOAT with "
"popart, but received type is %s.",
popart_var_dtype));
void Executor::SetIpuStrategy(const IpuStrategy &strategy) { if (paddle_var_dtype == popart_var_dtype) {
ipu_strategy_ = &strategy; VLOG(10) << weight_pair.first << " and " << weight_pair.second
<< " have the same dtype : " << popart_var_dtype;
continue;
} else if (paddle_var_dtype == popart::DataType::FLOAT) {
VLOG(10) << weight_pair.first << " and " << weight_pair.second
<< " have different dtype : " << popart_var_dtype;
auto *data_ptr =
paddle_var->GetMutable<framework::LoDTensor>()->data<float>();
auto num_elem = info.nelms();
if (align_to_popart) {
std::vector<uint16_t> fp16_data;
std::transform(data_ptr, data_ptr + num_elem,
std::back_inserter(fp16_data),
[&](float elem) { return popart::floatToHalf(elem); });
memcpy(reinterpret_cast<void *>(data_ptr), fp16_data.data(),
num_elem * sizeof(float16));
} else {
std::vector<float> fp32_data;
auto fp16_data_ptr = reinterpret_cast<uint16_t *>(data_ptr);
std::transform(fp16_data_ptr, fp16_data_ptr + num_elem,
std::back_inserter(fp32_data), [&](uint16_t elem) {
return popart::halfToFloat(elem);
});
memcpy(reinterpret_cast<void *>(data_ptr), fp32_data.data(),
num_elem * sizeof(float));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Convert Paddle FLOAT16 to popart FLOAT"));
}
}
} }
float Executor::GetLRFromScope() { // |-----------------------------------------------------|
auto lr_var = scope_->GetVar(opt_info.GetLRVarName()); // | Paddle | Popart | Method |
auto tensor = lr_var->Get<framework::LoDTensor>(); // |-----------------------------------------------------|
// | FLOAT | FLOAT | Paddle -> Popart |
// | FLOAT | FLOAT16 | floatToHalf -> Paddle -> Popart |
// | FLOAT16 | FLOAT | Unimplemented |
// | FLOAT16 | FLOAT16 | Paddle -> Popart |
// |-----------------------------------------------------|
// floatToHalf -> Paddle: cast then save to paddle
// Paddle -> Popart: copy from paddle to popart
void Executor::WeightsFromPaddle() {
ConvertWeights(true);
session_->writeWeights(executor_resources_->weights_io);
}
PADDLE_ENFORCE_EQ(tensor.type(), framework::proto::VarType::FP32, // |-----------------------------------------------------|
platform::errors::InvalidArgument( // | Paddle | Popart | Method |
"LR requiree float, but got (%s).", tensor.type())); // |-----------------------------------------------------|
// | FLOAT | FLOAT | Popart -> Paddle |
// | FLOAT | FLOAT16 | Popart -> Paddle -> halfToFloat |
// | FLOAT16 | FLOAT | Unimplemented |
// | FLOAT16 | FLOAT16 | Popart -> Paddle |
// |-----------------------------------------------------|
// Paddle -> halfToFloat: cast then save to paddle
// Popart -> Paddle: copy from paddle to popart
void Executor::WeightsToPaddle() {
session_->readWeights(executor_resources_->weights_io);
ConvertWeights(false);
}
return tensor.data<float>()[0]; void Executor::SaveModelToHost(const std::string &path) {
if (session_) {
session_->weightsToHost();
WeightsToPaddle();
session_->modelToHost(path);
} else {
LOG(WARNING) << "Model is empty";
}
} }
} // namespace ipu } // namespace ipu
......
...@@ -15,67 +15,85 @@ limitations under the License. */ ...@@ -15,67 +15,85 @@ limitations under the License. */
#pragma once #pragma once
#include <popart/dataflow.hpp> #include <popart/dataflow.hpp>
#include <popart/half.hpp>
#include <popart/names.hpp> #include <popart/names.hpp>
#include <popart/patterns/patterns.hpp>
#include <popart/session.hpp> #include <popart/session.hpp>
#include <popart/tensorinfo.hpp>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/ipu/common.h" #include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
#include "paddle/fluid/platform/ipu/ipu_optimizer.h" #include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/ipu/ipu_strategy.h" #include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
#include "paddle/fluid/platform/ipu/ipu_utils.h" #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace ipu { namespace ipu {
struct ExecutorResources {
// map<tensor_id, paddle_var_ptr>
popart::WeightsIO weights_io;
// <popart_var, paddle_var> pairs, include weights and optimizer states
std::vector<std::pair<popart::TensorId, popart::TensorId>>
weights_and_opt_state;
};
class Executor { class Executor {
public: public:
Executor(); Executor() = default;
~Executor(); ~Executor();
void Prepare(const std::string &proto, // build popart session
const std::map<std::string, popart::TensorId> &tensors, void Prepare(const std::string &proto);
const std::vector<popart::TensorId> &outputs,
std::shared_ptr<popart::DeviceInfo> device);
void Run(const std::vector<popart::TensorId> &inputs_id, // run popart session
const std::vector<const framework::Tensor *> &inputs, void Run(const std::vector<const Tensor *> &inputs,
const std::vector<popart::TensorId> &outputs_id, const std::vector<Tensor *> &outputs,
const std::vector<framework::Tensor *> &outputs,
const framework::ExecutionContext &ctx); const framework::ExecutionContext &ctx);
// Optimizer // detach IPU
void SetOptimizerType(const std::string &type); void Detach();
void SetOptimizerAttr(const std::string &attr, float value);
void SetLoss(const std::string &loss);
void SetLR(float lr_rate);
void SetLRVarName(const std::string &name);
void SetWeights(const std::vector<popart::TensorId> &info);
void SetWeightsIO(); void SetWeightsIO();
void ConvertWeights(bool align_to_popart);
void WeightsFromPaddle(); void WeightsFromPaddle();
void WeightsToPaddle(); void WeightsToPaddle();
// Scope // Scope
void SetScope(const framework::Scope *scope) { scope_ = scope; } void SetScope(const Scope *scope) { scope_ = scope; }
// Strategy // Strategy
void SetIpuStrategy(const IpuStrategy &strategy); void SetIpuStrategy(const IpuStrategy &strategy) {
ipu_strategy_ = &strategy;
}
private: // CompilerResources
float GetLRFromScope(); void SetCompilerResources(CompilerResources *resources) {
compiler_resources_ = resources;
}
public: // Save model to onnx
OptmizerMetaInfo opt_info; void SaveModelToHost(const std::string &path);
std::unique_ptr<popart::Session> session_;
private: private:
const framework::Scope *scope_ = nullptr; void AcquireDevice();
private:
// not own
const Scope *scope_ = nullptr;
const IpuStrategy *ipu_strategy_ = nullptr; const IpuStrategy *ipu_strategy_ = nullptr;
popart::WeightsIO weights_io_; CompilerResources *compiler_resources_ = nullptr;
std::vector<popart::TensorId> weights_;
// deviceinfo for popart session
std::shared_ptr<popart::DeviceInfo> device_;
// popart session, where graph running
std::unique_ptr<popart::Session> session_;
// one OneSession means a graph
std::unique_ptr<ExecutorResources> executor_resources_;
int step_ = 0;
}; };
} // namespace ipu } // namespace ipu
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_optimizer.h"
namespace paddle {
namespace platform {
namespace ipu {
OptmizerMetaInfo::OptmizerMetaInfo() {}
OptmizerMetaInfo::~OptmizerMetaInfo() {}
void OptmizerMetaInfo::SetType(const std::string &type) {
type_ = OptTypeStr2Enum(type);
}
float OptmizerMetaInfo::GetAttr(const std::string &attr,
float default_value) const {
if (attrs_.count(attr) == 0) {
return default_value;
}
return attrs_.at(attr);
}
void OptmizerMetaInfo::SetAttr(const std::string &attr, float value) {
attrs_[attr] = value;
}
OptimizerType OptTypeStr2Enum(const std::string type) {
if (type == "sgd") {
return OptimizerType::SGD;
} else if (type == "adam") {
return OptimizerType::Adam;
} else if (type == "lamb") {
return OptimizerType::Lamb;
} else {
return OptimizerType::Undefined;
}
}
std::unique_ptr<popart::Optimizer> GetPopartOptimizer(
const OptmizerMetaInfo &opt_meta_info) {
auto opt_type = opt_meta_info.GetType();
PADDLE_ENFORCE_NE(
opt_type, OptimizerType::Undefined,
platform::errors::InvalidArgument("Optimizer type have not been set."));
if (opt_type == OptimizerType::SGD) {
auto optimizer = std::make_unique<popart::SGD>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(popart::SGD::getUnsetWeightDecay()),
popart::OptimizerValue(popart::SGD::getUnsetMomentum()),
popart::OptimizerValue(popart::SGD::getUnsetDampening()),
popart::OptimizerValue(popart::SGD::getUnsetVelocityScaling()),
popart::OptimizerValue(popart::SGD::getUnsetLossScaling()));
return optimizer;
} else if (opt_type == OptimizerType::Adam) {
auto optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(popart::Adam::getUnsetWeightDecay()),
popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), false),
popart::OptimizerValue(popart::Adam::getUnsetLossScaling()),
popart::AdamMode::Adam, popart::WeightDecayMode::Decay,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
return optimizer;
} else if (opt_type == OptimizerType::Lamb) {
auto optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(opt_meta_info.GetLR(), false),
popart::OptimizerValue(opt_meta_info.GetAttr("weight_decay"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false),
popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), false),
popart::OptimizerValue(popart::Adam::getUnsetLossScaling()),
popart::AdamMode::Lamb, popart::WeightDecayMode::Decay,
popart::DataType::FLOAT, popart::DataType::FLOAT,
popart::DataType::FLOAT);
return optimizer;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Optimizer %d is not implemented now.", static_cast<int>(opt_type)));
}
}
bool IsOptimizerSupported(OptimizerType type) {
switch (type) {
case OptimizerType::SGD:
case OptimizerType::Adam:
case OptimizerType::Lamb:
return true;
default:
return false;
}
}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
OptimizerType opt_type) {
// format: {popart_tensor_id, paddle_tensor_id}, ...
std::vector<std::pair<std::string, std::string>> pre_post_fix;
switch (opt_type) {
case OptimizerType::SGD:
pre_post_fix.push_back(std::make_pair("", ""));
break;
case OptimizerType::Adam:
case OptimizerType::Lamb:
pre_post_fix.push_back(std::make_pair("", ""));
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
break;
default:
pre_post_fix.push_back(std::make_pair("", ""));
break;
}
return pre_post_fix;
}
} // namespace ipu
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <popart/adam.hpp>
#include <popart/names.hpp>
#include <popart/optimizer.hpp>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
enum class OptimizerType { SGD = 0, Adam, Lamb, Undefined };
class OptmizerMetaInfo {
public:
OptmizerMetaInfo();
~OptmizerMetaInfo();
void SetType(const std::string &type);
OptimizerType GetType() const { return type_; }
void SetAttr(const std::string &attr, float value);
float GetAttr(const std::string &attr, float default_value = 0.0f) const;
void SetLoss(const std::string &loss) { loss_ = loss; }
std::string GetLoss() const { return loss_; }
void SetLR(float lr_rate) { lr_rate_ = lr_rate; }
float GetLR() const { return lr_rate_; }
void SetLRVarName(const std::string &name) { lr_var_name_ = name; }
std::string GetLRVarName() const { return lr_var_name_; }
private:
// type: adam, sgd, ...
OptimizerType type_ = OptimizerType::Undefined;
// loss: loss TensorId
std::string loss_;
// attrs: beta1, beta2, ...
std::map<std::string, float> attrs_;
// learning rate
float lr_rate_ = 1.0;
std::string lr_var_name_;
};
OptimizerType OptTypeStr2Enum(const std::string type);
std::unique_ptr<popart::Optimizer> GetPopartOptimizer(
const OptmizerMetaInfo &info);
bool IsOptimizerSupported(OptimizerType type);
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
OptimizerType type);
} // namespace ipu
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册