提交 dadd48a5 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #963 from reyoung/feature/add_const_in_parameter_updater

Add const in ParameterUpdater init
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) {
......
......@@ -32,7 +32,7 @@ public:
parameterTypes_.push_back(type);
}
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
// called by Trainer when starting a new pass
virtual void startPass() {}
......@@ -105,7 +105,7 @@ public:
ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {}
virtual void init(std::vector<ParameterPtr>& parameters) = 0;
virtual void init(const std::vector<ParameterPtr>& parameters) = 0;
virtual void startPass() {
syncThreadPool_->execPlusOwner(
......
......@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
}
void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) {
void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size());
......
......@@ -64,7 +64,7 @@ public:
* be initialized.
* @param parameters The parameter need to be initialized.
*/
virtual void init(std::vector<ParameterPtr>& parameters) {
virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs
......@@ -208,7 +208,7 @@ public:
* @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL);
......
......@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM);
}
void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
if (localUpdater_) {
......@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing),
useApplyInPserver_(false) {}
void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
parameterClient_.reset(new ParameterClient2(
......@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
}
void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) {
const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
......
......@@ -67,7 +67,7 @@ public:
/**
* initialize the internal parameter client and itself.
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
/**
* @brief start batch
*
......@@ -274,7 +274,7 @@ public:
}
/// initialization
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
/// stateful batch control
virtual PassType startBatch(int64_t batchSize);
......@@ -360,7 +360,7 @@ public:
}
/// initialization of dense and sparse updaters
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
};
class ParameterUpdaterCreators {
......
......@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
}
}
void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
// calc max parameter id
......
......@@ -49,7 +49,7 @@ public:
// Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost);
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize);
// Call finishBatch for each optimizer.
virtual void finishBatch(real cost);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册