提交 06944ee1 编写于 作者: Y Yu Yang

Merge branch 'feature/add_const_in_parameter_updater' into feature/mnist_train_api

...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters; parameters_ = parameters;
for (ParameterType type : getParameterTypes()) { for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) { for (auto& para : parameters) {
......
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
parameterTypes_.push_back(type); parameterTypes_.push_back(type);
} }
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
// called by Trainer when starting a new pass // called by Trainer when starting a new pass
virtual void startPass() {} virtual void startPass() {}
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
ParameterUpdaterComposite() {} ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {} virtual ~ParameterUpdaterComposite() {}
virtual void init(std::vector<ParameterPtr>& parameters) = 0; virtual void init(const std::vector<ParameterPtr>& parameters) = 0;
virtual void startPass() { virtual void startPass() {
syncThreadPool_->execPlusOwner( syncThreadPool_->execPlusOwner(
......
...@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager( ...@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); }); updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
} }
void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) { void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters); SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr); averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size()); copyEvents_.resize(parameters_.size());
......
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
* be initialized. * be initialized.
* @param parameters The parameter need to be initialized. * @param parameters The parameter need to be initialized.
*/ */
virtual void init(std::vector<ParameterPtr>& parameters) { virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr); optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs // check no L1 decay in parameter configs
...@@ -208,7 +208,7 @@ public: ...@@ -208,7 +208,7 @@ public:
* @brief init. Initialize cpu parameters, model average optimizer. * @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters * @param parameters
*/ */
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize) { virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL); averager_->startBatch(-1UL);
......
...@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater( ...@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM); addParameterType(PARAMETER_MOMENTUM);
} }
void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
if (localUpdater_) { if (localUpdater_) {
...@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater( ...@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing), testing_(testing),
useApplyInPserver_(false) {} useApplyInPserver_(false) {}
void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
parameterClient_.reset(new ParameterClient2( parameterClient_.reset(new ParameterClient2(
...@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote( ...@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
} }
void SparseRemoteParameterUpdaterComposite::init( void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) { const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters; parameters_ = parameters;
std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS]; std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
......
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
/** /**
* initialize the internal parameter client and itself. * initialize the internal parameter client and itself.
*/ */
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
/** /**
* @brief start batch * @brief start batch
* *
...@@ -274,7 +274,7 @@ public: ...@@ -274,7 +274,7 @@ public:
} }
/// initialization /// initialization
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
/// stateful batch control /// stateful batch control
virtual PassType startBatch(int64_t batchSize); virtual PassType startBatch(int64_t batchSize);
...@@ -360,7 +360,7 @@ public: ...@@ -360,7 +360,7 @@ public:
} }
/// initialization of dense and sparse updaters /// initialization of dense and sparse updaters
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
}; };
class ParameterUpdaterCreators { class ParameterUpdaterCreators {
......
...@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig) ...@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
} }
} }
void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) { void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
// calc max parameter id // calc max parameter id
......
...@@ -49,7 +49,7 @@ public: ...@@ -49,7 +49,7 @@ public:
// Use the finishPass() function of the base optimizer. // Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost); virtual bool finishPass(real cost);
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize); virtual PassType startBatch(int64_t batchSize);
// Call finishBatch for each optimizer. // Call finishBatch for each optimizer.
virtual void finishBatch(real cost); virtual void finishBatch(real cost);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册