提交 20249e8e 编写于 作者: Y Yu Yang

Try expose ParamUpdater::update

上级 efb5c10c
...@@ -45,7 +45,6 @@ def main(): ...@@ -45,7 +45,6 @@ def main():
config.model_config, api.CREATE_MODE_NORMAL, enable_types) config.model_config, api.CREATE_MODE_NORMAL, enable_types)
assert isinstance(m, api.GradientMachine) assert isinstance(m, api.GradientMachine)
init_parameter(network=m) init_parameter(network=m)
updater = api.ParameterUpdater.createLocalUpdater(opt_config) updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater) assert isinstance(updater, api.ParameterUpdater)
updater.init(m) updater.init(m)
...@@ -62,7 +61,7 @@ def main(): ...@@ -62,7 +61,7 @@ def main():
train_data_generator = input_order_converter( train_data_generator = input_order_converter(
read_from_mnist(train_file)) read_from_mnist(train_file))
for data_batch in generator_to_batch(train_data_generator, 128): for data_batch in generator_to_batch(train_data_generator, 128):
inArgs = converter(data_batch) trainRole = updater.startBatch(len(data_batch))
updater.finishPass() updater.finishPass()
......
...@@ -799,6 +799,12 @@ public: ...@@ -799,6 +799,12 @@ public:
void finishPass(); void finishPass();
PassType startBatch(int64_t batchSize);
void finishBatch(float cost);
void update(Parameter* param);
private: private:
ParameterUpdaterPrivate* m; ParameterUpdaterPrivate* m;
}; };
......
...@@ -35,3 +35,16 @@ void ParameterUpdater::init(const GradientMachine &gm) { ...@@ -35,3 +35,16 @@ void ParameterUpdater::init(const GradientMachine &gm) {
void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() { m->updater->finishPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); }
PassType ParameterUpdater::startBatch(int64_t batchSize) {
return m->updater->startBatch(batchSize);
}
void ParameterUpdater::finishBatch(float cost) {
m->updater->finishBatch(cost);
}
void ParameterUpdater::update(Parameter *param) {
auto paddleParam = param->m->getPtr();
m->updater->update(paddleParam);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册