diff --git a/paddle/fluid/platform/device/ipu/ipu_executor.cc b/paddle/fluid/platform/device/ipu/ipu_executor.cc index a7978ba6f37b130390a7f5bc97b592855d64ce5a..df2f1c786e947334cc3f44685fea797fd5113df6 100644 --- a/paddle/fluid/platform/device/ipu/ipu_executor.cc +++ b/paddle/fluid/platform/device/ipu/ipu_executor.cc @@ -12,52 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/ipu/ipu_executor.h" +#include "paddle/fluid/platform/device/ipu/ipu_executor.h" + +using float16 = paddle::platform::float16; namespace paddle { namespace platform { namespace ipu { -Executor::Executor() {} +Executor::~Executor() { + Detach(); + session_.reset(); + executor_resources_.reset(); +} + +void Executor::Prepare(const std::string &proto) { + VLOG(10) << "enter Executor::Prepare"; -Executor::~Executor() {} + AcquireDevice(); + executor_resources_ = std::make_unique(); -void Executor::Prepare(const std::string &proto, - const std::map &tensors, - const std::vector &outputs, - std::shared_ptr device) { auto art = popart::AnchorReturnType("All"); std::map anchor_ids; - for (const auto &id : outputs) { + for (const auto &id : compiler_resources_->outputs) { anchor_ids.emplace(id, art); } - auto dataFlow = popart::DataFlow(ipu_strategy_->batches_per_step, anchor_ids); - PADDLE_ENFORCE_NOT_NULL(device, platform::errors::Unavailable( - "IPU device isn't attached, please call " - "IpuBackend::AttachDevice(id) first.")); - - if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) { + if (ipu_strategy_->is_training) { VLOG(10) << "Creating TrainingSession from Onnx Model..."; - auto popart_optimizer = GetPopartOptimizer(opt_info); - - auto it = tensors.find(opt_info.GetLoss()); - PADDLE_ENFORCE_NE( - it, tensors.end(), - paddle::platform::errors::InvalidArgument( - "loss_id = %s doesn't exist in popart graph.", 
opt_info.GetLoss())); - + auto optimizer = compiler_resources_->NewOptimizer(); session_ = popart::TrainingSession::createFromOnnxModel( - proto, dataFlow, it->second, *popart_optimizer, device, - popart::InputShapeInfo(), ipu_strategy_->popart_options_, - popart::Patterns(popart::PatternsLevel::Default)); + proto, dataFlow, compiler_resources_->loss_var, *optimizer, device_, + popart::InputShapeInfo(), ipu_strategy_->popart_options, + ipu_strategy_->popart_patterns); } else { VLOG(10) << "Creating InferenceSession from Onnx Model..."; session_ = popart::InferenceSession::createFromOnnxModel( - proto, dataFlow, device, popart::InputShapeInfo(), - ipu_strategy_->popart_options_, - popart::Patterns(popart::PatternsLevel::Default)); + proto, dataFlow, device_, popart::InputShapeInfo(), + ipu_strategy_->popart_options, ipu_strategy_->popart_patterns); } VLOG(10) << "Creating session from Onnx Model...done"; @@ -78,30 +71,27 @@ void Executor::Prepare(const std::string &proto, if (ipu_strategy_->save_init_onnx) { session_->modelToHost("test_init.onnx"); } + // init run step + step_ = 0; } -void Executor::Run(const std::vector &inputs_id, - const std::vector &inputs, - const std::vector &outputs_id, - const std::vector &outputs, +void Executor::Run(const std::vector &inputs, + const std::vector &outputs, const framework::ExecutionContext &ctx) { + VLOG(10) << "enter Executor::Run"; // inputs std::map popart_inputs; std::map input_wrappers; for (size_t i = 0; i < inputs.size(); i++) { - auto tensor_id = inputs_id[i]; - framework::Tensor *tensor = nullptr; - tensor->ShareDataWith(*inputs[i]); - input_wrappers.emplace(tensor_id, PaddleIArray(tensor)); + auto tensor_id = compiler_resources_->inputs[i]; + input_wrappers.emplace(tensor_id, PaddleIArray(inputs[i])); popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id)); } // anchors std::map popart_anchors; std::map anchor_wrappers; for (size_t i = 0; i < outputs.size(); i++) { - auto tensor_id = outputs_id[i]; - 
framework::Tensor *tensor = nullptr; - tensor->ShareDataWith(*outputs[i]); + auto tensor_id = compiler_resources_->outputs[i]; // get dims & dtype from session auto fetch_info = session_->getInfo(tensor_id); auto output_shape = fetch_info.shape(); @@ -109,6 +99,16 @@ void Executor::Run(const std::vector &inputs_id, output_shape.insert(output_shape.begin(), ipu_strategy_->batches_per_step); } + if (ipu_strategy_->popart_options.enableGradientAccumulation) { + output_shape.insert(output_shape.begin(), + ipu_strategy_->popart_options.accumulationFactor); + } + if (ipu_strategy_->popart_options.enableReplicatedGraphs) { + output_shape.insert(output_shape.begin(), + ipu_strategy_->popart_options.replicatedGraphCount); + } + + auto *tensor = outputs[i]; tensor->Resize(framework::make_ddim(output_shape)); auto fetch_dtype = fetch_info.dataType(); auto paddle_type = PopartType2VarType(fetch_dtype); @@ -116,13 +116,16 @@ void Executor::Run(const std::vector &inputs_id, anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor)); popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id)); } - - if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) { - VLOG(10) << "Update optimizer learning rate..."; - SetLR(GetLRFromScope()); - auto popart_optimizer = GetPopartOptimizer(opt_info); - auto &session = dynamic_cast(*session_); - session.updateOptimizerFromHost(popart_optimizer.get()); + VLOG(10) << "Prepared inputs/anchors"; + + if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched) { + VLOG(10) << "Update learning_rate"; + auto new_lr = + GetSingleVarFromScope(scope_, compiler_resources_->lr_var); + VLOG(10) << "New Lr: " << new_lr; + auto *optimizer = compiler_resources_->UpdateOptimizer(new_lr); + auto *session = dynamic_cast(session_.get()); + session->updateOptimizerFromHost(optimizer); } popart::StepIO stepio(popart_inputs, popart_anchors); @@ -130,44 +133,54 @@ void Executor::Run(const std::vector &inputs_id, session_->run(stepio); VLOG(10) << 
"Running...done"; - if (ipu_strategy_ != nullptr && ipu_strategy_->is_training) { + step_++; + if (ipu_strategy_->is_training && + step_ % ipu_strategy_->save_per_n_step == 0) { session_->weightsToHost(); WeightsToPaddle(); - if (ipu_strategy_->save_last_onnx) { - session_->modelToHost("test_last.onnx"); + if (ipu_strategy_->save_onnx_checkpoint) { + session_->modelToHost("test_last" + std::to_string(step_) + ".onnx"); } } } -void Executor::SetOptimizerType(const std::string &type) { - opt_info.SetType(type); -} - -void Executor::SetLR(float lr_rate) { opt_info.SetLR(lr_rate); } - -void Executor::SetOptimizerAttr(const std::string &attr, float value) { - opt_info.SetAttr(attr, value); -} - -void Executor::SetLoss(const std::string &loss) { opt_info.SetLoss(loss); } +void Executor::AcquireDevice() { + VLOG(10) << "enter Executor::AcquireDevice"; + if (device_) { + Detach(); + device_.reset(); + } -void Executor::SetLRVarName(const std::string &name) { - opt_info.SetLRVarName(name); + bool use_ipu_model = GetBoolEnv("POPLAR_IPUMODEL"); + if (use_ipu_model) { + std::map deviceOpts{{"numIPUs", "1 "}}; + device_ = popart::DeviceManager::createDeviceManager().createIpuModelDevice( + deviceOpts); + } else { + device_ = + popart::DeviceManager::createDeviceManager().acquireAvailableDevice( + RequestIpus(ipu_strategy_->num_ipus)); + PADDLE_ENFORCE_NOT_NULL(device_, platform::errors::Unavailable( + "Can't attach IPU, ipu_num = %d.", + RequestIpus(ipu_strategy_->num_ipus))); + } + VLOG(10) << "leave Executor::AcquireDevice"; } -void Executor::SetWeights(const std::vector &weights) { - weights_ = weights; +void Executor::Detach() { + if (device_ && device_->isAttached()) { + VLOG(10) << "trying to detach IPU"; + device_->detach(); + VLOG(10) << " detached IPU"; + } } void Executor::SetWeightsIO() { - auto opt_type = opt_info.GetType(); + auto opt_type = compiler_resources_->optimizer_type; + VLOG(10) << "SetWeightsIO for " << opt_type; auto pre_post_fix = 
GetOptPrePostfix(opt_type); - for (const auto &weight_id : weights_) { + for (const auto &weight_id : compiler_resources_->weights) { for (const auto &pair : pre_post_fix) { - if (!IsOptimizerSupported(opt_type)) { - continue; - } - // pair.first : popart prefix, pair.second : paddle postfix auto popart_var_name = pair.first + weight_id; auto paddle_var_name = weight_id + pair.second; @@ -176,32 +189,120 @@ void Executor::SetWeightsIO() { continue; } + if (!session_->hasInfo(popart_var_name)) { + continue; + } + auto var = scope_->GetVar(paddle_var_name); - auto data_ptr = var->GetMutable()->data(); + auto data_ptr = var->GetMutable()->data(); auto tensor_info = session_->getInfo(popart_var_name); - weights_io_.insert(popart_var_name, {data_ptr, tensor_info}); + executor_resources_->weights_io.insert(popart_var_name, + {data_ptr, tensor_info}); + executor_resources_->weights_and_opt_state.emplace_back( + std::make_pair(popart_var_name, paddle_var_name)); } } } -void Executor::WeightsFromPaddle() { session_->writeWeights(weights_io_); } - -void Executor::WeightsToPaddle() { session_->readWeights(weights_io_); } - -void Executor::SetIpuStrategy(const IpuStrategy &strategy) { - ipu_strategy_ = &strategy; +// align_to_popart: align dtype to popart if true, else to paddle +void Executor::ConvertWeights(bool align_to_popart) { + for (auto weight_pair : executor_resources_->weights_and_opt_state) { + auto paddle_var = scope_->GetVar(weight_pair.second); + auto paddle_var_dtype = VarType2PopartType( + paddle_var->GetMutable()->type()); + + PADDLE_ENFORCE_EQ((paddle_var_dtype == popart::DataType::FLOAT || + paddle_var_dtype == popart::DataType::FLOAT16), + true, + platform::errors::InvalidArgument( + "Currently, we only support FLOAT16 and FLOAT with " + "Paddle, but received type is %s.", + paddle_var_dtype)); + + popart::TensorInfo info = session_->getInfo(weight_pair.first); + auto popart_var_dtype = info.dataType(); + PADDLE_ENFORCE_EQ((popart_var_dtype == 
popart::DataType::FLOAT || + popart_var_dtype == popart::DataType::FLOAT16), + true, + platform::errors::InvalidArgument( + "Currently, we only support FLOAT16 and FLOAT with " + "popart, but received type is %s.", + popart_var_dtype)); + + if (paddle_var_dtype == popart_var_dtype) { + VLOG(10) << weight_pair.first << " and " << weight_pair.second + << " have the same dtype : " << popart_var_dtype; + continue; + } else if (paddle_var_dtype == popart::DataType::FLOAT) { + VLOG(10) << weight_pair.first << " and " << weight_pair.second + << " have different dtype : " << popart_var_dtype; + auto *data_ptr = + paddle_var->GetMutable()->data(); + + auto num_elem = info.nelms(); + if (align_to_popart) { + std::vector fp16_data; + std::transform(data_ptr, data_ptr + num_elem, + std::back_inserter(fp16_data), + [&](float elem) { return popart::floatToHalf(elem); }); + memcpy(reinterpret_cast(data_ptr), fp16_data.data(), + num_elem * sizeof(float16)); + } else { + std::vector fp32_data; + auto fp16_data_ptr = reinterpret_cast(data_ptr); + std::transform(fp16_data_ptr, fp16_data_ptr + num_elem, + std::back_inserter(fp32_data), [&](uint16_t elem) { + return popart::halfToFloat(elem); + }); + memcpy(reinterpret_cast(data_ptr), fp32_data.data(), + num_elem * sizeof(float)); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Convert Paddle FLOAT16 to popart FLOAT")); + } + } } -float Executor::GetLRFromScope() { - auto lr_var = scope_->GetVar(opt_info.GetLRVarName()); - auto tensor = lr_var->Get(); +// |-----------------------------------------------------| +// | Paddle | Popart | Method | +// |-----------------------------------------------------| +// | FLOAT | FLOAT | Paddle -> Popart | +// | FLOAT | FLOAT16 | floatToHalf -> Paddle -> Popart | +// | FLOAT16 | FLOAT | Unimplemented | +// | FLOAT16 | FLOAT16 | Paddle -> Popart | +// |-----------------------------------------------------| +// floatToHalf -> Paddle: cast then save to paddle +// Paddle -> Popart: copy 
from paddle to popart +void Executor::WeightsFromPaddle() { + ConvertWeights(true); + session_->writeWeights(executor_resources_->weights_io); +} - PADDLE_ENFORCE_EQ(tensor.type(), framework::proto::VarType::FP32, - platform::errors::InvalidArgument( - "LR requiree float, but got (%s).", tensor.type())); +// |-----------------------------------------------------| +// | Paddle | Popart | Method | +// |-----------------------------------------------------| +// | FLOAT | FLOAT | Popart -> Paddle | +// | FLOAT | FLOAT16 | Popart -> Paddle -> halfToFloat | +// | FLOAT16 | FLOAT | Unimplemented | +// | FLOAT16 | FLOAT16 | Popart -> Paddle | +// |-----------------------------------------------------| +// Paddle -> halfToFloat: cast then save to paddle +// Popart -> Paddle: copy from popart to paddle +void Executor::WeightsToPaddle() { + session_->readWeights(executor_resources_->weights_io); + ConvertWeights(false); +} - return tensor.data()[0]; +void Executor::SaveModelToHost(const std::string &path) { + if (session_) { + session_->weightsToHost(); + WeightsToPaddle(); + session_->modelToHost(path); + } else { + LOG(WARNING) << "Model is empty"; + } } } // namespace ipu diff --git a/paddle/fluid/platform/device/ipu/ipu_executor.h b/paddle/fluid/platform/device/ipu/ipu_executor.h index 400884a2c2b0fd6e5e5e8fd96a32f65a6efe2074..b08b94b45ff65d9e04da0447f55801859a59bb1b 100644 --- a/paddle/fluid/platform/device/ipu/ipu_executor.h +++ b/paddle/fluid/platform/device/ipu/ipu_executor.h @@ -15,67 +15,85 @@ limitations under the License.
*/ #pragma once #include +#include #include +#include #include +#include #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/platform/ipu/common.h" -#include "paddle/fluid/platform/ipu/ipu_optimizer.h" -#include "paddle/fluid/platform/ipu/ipu_strategy.h" -#include "paddle/fluid/platform/ipu/ipu_utils.h" +#include "paddle/fluid/platform/device/ipu/ipu_compiler.h" +#include "paddle/fluid/platform/device/ipu/ipu_names.h" +#include "paddle/fluid/platform/device/ipu/ipu_strategy.h" +#include "paddle/fluid/platform/device/ipu/ipu_utils.h" namespace paddle { namespace platform { namespace ipu { +struct ExecutorResources { + // map + popart::WeightsIO weights_io; + // pairs, include weights and optimizer states + std::vector> + weights_and_opt_state; +}; + class Executor { public: - Executor(); + Executor() = default; ~Executor(); - void Prepare(const std::string &proto, - const std::map &tensors, - const std::vector &outputs, - std::shared_ptr device); + // build popart session + void Prepare(const std::string &proto); - void Run(const std::vector &inputs_id, - const std::vector &inputs, - const std::vector &outputs_id, - const std::vector &outputs, + // run popart session + void Run(const std::vector &inputs, + const std::vector &outputs, const framework::ExecutionContext &ctx); - // Optimizer - void SetOptimizerType(const std::string &type); - void SetOptimizerAttr(const std::string &attr, float value); - void SetLoss(const std::string &loss); - void SetLR(float lr_rate); - void SetLRVarName(const std::string &name); - - void SetWeights(const std::vector &info); + // detach IPU + void Detach(); void SetWeightsIO(); + void ConvertWeights(bool align_to_popart); void WeightsFromPaddle(); void WeightsToPaddle(); // Scope - void SetScope(const framework::Scope *scope) { scope_ = scope; } + void SetScope(const Scope *scope) { scope_ = scope; } // Strategy - void SetIpuStrategy(const IpuStrategy &strategy); + void 
SetIpuStrategy(const IpuStrategy &strategy) { + ipu_strategy_ = &strategy; + } - private: - float GetLRFromScope(); + // CompilerResources + void SetCompilerResources(CompilerResources *resources) { + compiler_resources_ = resources; + } - public: - OptmizerMetaInfo opt_info; - std::unique_ptr session_; + // Save model to onnx + void SaveModelToHost(const std::string &path); private: - const framework::Scope *scope_ = nullptr; + void AcquireDevice(); + + private: + // not own + const Scope *scope_ = nullptr; const IpuStrategy *ipu_strategy_ = nullptr; - popart::WeightsIO weights_io_; - std::vector weights_; + CompilerResources *compiler_resources_ = nullptr; + + // deviceinfo for popart session + std::shared_ptr device_; + // popart session, where graph running + std::unique_ptr session_; + // one OneSession means a graph + std::unique_ptr executor_resources_; + + int step_ = 0; }; } // namespace ipu diff --git a/paddle/fluid/platform/device/ipu/ipu_optimizer.cc b/paddle/fluid/platform/device/ipu/ipu_optimizer.cc deleted file mode 100644 index 92bb2ca3afcf881d0092a2c0bde74cfe322b1a8a..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/device/ipu/ipu_optimizer.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/platform/device/ipu/ipu_optimizer.h" - -namespace paddle { -namespace platform { -namespace ipu { - -OptmizerMetaInfo::OptmizerMetaInfo() {} - -OptmizerMetaInfo::~OptmizerMetaInfo() {} - -void OptmizerMetaInfo::SetType(const std::string &type) { - type_ = OptTypeStr2Enum(type); -} - -float OptmizerMetaInfo::GetAttr(const std::string &attr, - float default_value) const { - if (attrs_.count(attr) == 0) { - return default_value; - } - return attrs_.at(attr); -} - -void OptmizerMetaInfo::SetAttr(const std::string &attr, float value) { - attrs_[attr] = value; -} - -OptimizerType OptTypeStr2Enum(const std::string type) { - if (type == "sgd") { - return OptimizerType::SGD; - } else if (type == "adam") { - return OptimizerType::Adam; - } else if (type == "lamb") { - return OptimizerType::Lamb; - } else { - return OptimizerType::Undefined; - } -} - -std::unique_ptr GetPopartOptimizer( - const OptmizerMetaInfo &opt_meta_info) { - auto opt_type = opt_meta_info.GetType(); - PADDLE_ENFORCE_NE( - opt_type, OptimizerType::Undefined, - platform::errors::InvalidArgument("Optimizer type have not been set.")); - - if (opt_type == OptimizerType::SGD) { - auto optimizer = std::make_unique( - popart::OptimizerValue(opt_meta_info.GetLR(), false), - popart::OptimizerValue(popart::SGD::getUnsetWeightDecay()), - popart::OptimizerValue(popart::SGD::getUnsetMomentum()), - popart::OptimizerValue(popart::SGD::getUnsetDampening()), - popart::OptimizerValue(popart::SGD::getUnsetVelocityScaling()), - popart::OptimizerValue(popart::SGD::getUnsetLossScaling())); - return optimizer; - } else if (opt_type == OptimizerType::Adam) { - auto optimizer = std::make_unique( - popart::OptimizerValue(opt_meta_info.GetLR(), false), - popart::OptimizerValue(popart::Adam::getUnsetWeightDecay()), - popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false), - popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false), - popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), 
false), - popart::OptimizerValue(popart::Adam::getUnsetLossScaling()), - popart::AdamMode::Adam, popart::WeightDecayMode::Decay, - popart::DataType::FLOAT, popart::DataType::FLOAT, - popart::DataType::FLOAT); - return optimizer; - } else if (opt_type == OptimizerType::Lamb) { - auto optimizer = std::make_unique( - popart::OptimizerValue(opt_meta_info.GetLR(), false), - popart::OptimizerValue(opt_meta_info.GetAttr("weight_decay"), false), - popart::OptimizerValue(opt_meta_info.GetAttr("beta1"), false), - popart::OptimizerValue(opt_meta_info.GetAttr("beta2"), false), - popart::OptimizerValue(opt_meta_info.GetAttr("epsilon"), false), - popart::OptimizerValue(popart::Adam::getUnsetLossScaling()), - popart::AdamMode::Lamb, popart::WeightDecayMode::Decay, - popart::DataType::FLOAT, popart::DataType::FLOAT, - popart::DataType::FLOAT); - return optimizer; - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Optimizer %d is not implemented now.", static_cast(opt_type))); - } -} - -bool IsOptimizerSupported(OptimizerType type) { - switch (type) { - case OptimizerType::SGD: - case OptimizerType::Adam: - case OptimizerType::Lamb: - return true; - default: - return false; - } -} - -std::vector> GetOptPrePostfix( - OptimizerType opt_type) { - // format: {popart_tensor_id, paddle_tensor_id}, ... 
- std::vector> pre_post_fix; - - switch (opt_type) { - case OptimizerType::SGD: - pre_post_fix.push_back(std::make_pair("", "")); - break; - case OptimizerType::Adam: - case OptimizerType::Lamb: - pre_post_fix.push_back(std::make_pair("", "")); - pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0")); - pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0")); - pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0")); - break; - default: - pre_post_fix.push_back(std::make_pair("", "")); - break; - } - - return pre_post_fix; -} - -} // namespace ipu -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/ipu_optimizer.h b/paddle/fluid/platform/device/ipu/ipu_optimizer.h deleted file mode 100644 index ee16abce398fb65ccbea523f3d87c732d119453b..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/device/ipu/ipu_optimizer.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace platform { -namespace ipu { - -enum class OptimizerType { SGD = 0, Adam, Lamb, Undefined }; - -class OptmizerMetaInfo { - public: - OptmizerMetaInfo(); - ~OptmizerMetaInfo(); - - void SetType(const std::string &type); - OptimizerType GetType() const { return type_; } - - void SetAttr(const std::string &attr, float value); - float GetAttr(const std::string &attr, float default_value = 0.0f) const; - - void SetLoss(const std::string &loss) { loss_ = loss; } - std::string GetLoss() const { return loss_; } - - void SetLR(float lr_rate) { lr_rate_ = lr_rate; } - float GetLR() const { return lr_rate_; } - - void SetLRVarName(const std::string &name) { lr_var_name_ = name; } - std::string GetLRVarName() const { return lr_var_name_; } - - private: - // type: adam, sgd, ... - OptimizerType type_ = OptimizerType::Undefined; - - // loss: loss TensorId - std::string loss_; - - // attrs: beta1, beta2, ... - std::map attrs_; - - // learning rate - float lr_rate_ = 1.0; - std::string lr_var_name_; -}; - -OptimizerType OptTypeStr2Enum(const std::string type); - -std::unique_ptr GetPopartOptimizer( - const OptmizerMetaInfo &info); - -bool IsOptimizerSupported(OptimizerType type); - -std::vector> GetOptPrePostfix( - OptimizerType type); - -} // namespace ipu -} // namespace platform -} // namespace paddle