From 8b4cbcfc1847c50228c151a485755202912e7df2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Dec 2016 22:01:28 +0800 Subject: [PATCH] Start doing mnist_train_api --- demo/mnist/api_train.py | 31 ++++++++++++++++++++++++++--- paddle/api/CMakeLists.txt | 1 + paddle/api/Paddle.swig | 3 ++- paddle/api/PaddleAPI.h | 20 +++++++++++++++++++ paddle/api/PaddleAPIPrivate.h | 27 +++++++++++++++++++++++-- paddle/api/Parameter.cpp | 16 +-------------- paddle/api/ParameterUpdater.cpp | 35 +++++++++++++++++++++++++++++++++ 7 files changed, 112 insertions(+), 21 deletions(-) create mode 100644 paddle/api/ParameterUpdater.cpp diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 6abb5d4e562..5d4ef90f10d 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -1,11 +1,36 @@ import py_paddle.swig_paddle as api -from paddle.trainer.config_parser import parse_config +import paddle.trainer.config_parser +import numpy as np + + +def init_parameter(network): + assert isinstance(network, api.GradientMachine) + for each_param in network.getParameters(): + assert isinstance(each_param, api.Parameter) + array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace() + assert isinstance(array, np.ndarray) + for i in xrange(len(array)): + array[i] = np.random.uniform(-1.0, 1.0) def main(): api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores - config = parse_config('simple_mnist_network.py', '') - m = api.GradientMachine.createFromConfigProto(config.model_config) + config = paddle.trainer.config_parser.parse_config( + 'simple_mnist_network.py', '') + + opt_config = api.OptimizationConfig.createFromProto(config.opt_config) + _temp_optimizer_ = api.ParameterOptimizer.create(opt_config) + enable_types = _temp_optimizer_.getParameterTypes() + + m = api.GradientMachine.createFromConfigProto( + config.model_config, api.CREATE_MODE_NORMAL, enable_types) + assert isinstance(m, api.GradientMachine) + init_parameter(network=m) + + updater = api.ParameterUpdater.createLocalUpdater(opt_config) + assert isinstance(updater, api.ParameterUpdater) + updater.init(m) + updater.startPass() if __name__ == '__main__': diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 6ad1d79e59b..39fe4355659 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -5,6 +5,7 @@ set(API_SOURCES Matrix.cpp Parameter.cpp ParameterOptimizer.cpp + ParameterUpdater.cpp SequenceGenerator.cpp Trainer.cpp Util.cpp diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 9194a6371be..b0fa8beb166 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -174,6 +174,7 @@ namespace std { %newobject Parameter::getConfig; %newobject ParameterOptimizer::create; %newobject ParameterOptimizer::needSpecialTraversal; +%newobject ParameterUpdater::createLocalUpdater; %feature("director") UpdateCallback; %feature("autodoc", 1); // To generate method stub, for code hint in ide @@ -193,4 +194,4 @@ namespace std { %ignore OptimizationConfigPrivate; %ignore ParameterTraverseCallbackPrivate; %include "utils/GlobalConstants.h" -%include "api/PaddleAPI.h" \ No newline at end of file +%include "api/PaddleAPI.h" diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 84a66719c33..bd413eb1e9d 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -519,6 +519,7 @@ private: friend class TrainerConfig; friend class ParameterOptimizer; + friend class ParameterUpdater; friend class Trainer; }; @@ -557,6 +558,7 @@ private: ParameterPrivate* m; friend class UpdateCallbackWrapper; friend class GradientMachine; + friend class ParameterUpdater; }; struct ModelConfigPrivate; @@ -772,6 +774,24 @@ private: // Not to use c++ 11 init-list, so we use static var as function default arg. static std::vector defaultParamTypes; friend class Trainer; + friend class ParameterUpdater; +}; + +struct ParameterUpdaterPrivate; +class ParameterUpdater { +private: + ParameterUpdater(); + +public: + static ParameterUpdater* createLocalUpdater(OptimizationConfig* config); + ~ParameterUpdater(); + + void init(const GradientMachine& gm); + + void startPass(); + +private: + ParameterUpdaterPrivate* m; }; struct TrainerPrivate; diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/api/PaddleAPIPrivate.h index d2b56fc41c8..905668a62f2 100644 --- a/paddle/api/PaddleAPIPrivate.h +++ b/paddle/api/PaddleAPIPrivate.h @@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - +#pragma once +#include +#include "PaddleAPI.h" #include "paddle/gserver/gradientmachines/GradientMachine.h" #include "paddle/trainer/TrainerConfigHelper.h" -#pragma once +#include "paddle/parameter/ParameterUpdaterBase.h" struct GradientMachinePrivate { std::shared_ptr machine; @@ -65,3 +67,24 @@ struct ArgumentsPrivate { return *(std::shared_ptr*)(rawPtr); } }; + +struct ParameterUpdaterPrivate { + std::unique_ptr updater; +}; + +struct ParameterPrivate { + std::shared_ptr sharedPtr; + paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater, + // in other situation sharedPtr should + // contains value. + + ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {} + + paddle::Parameter* getPtr() { + if (sharedPtr) { + return sharedPtr.get(); + } else { + return rawPtr; + } + } +}; diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index 4eed00a84a6..41cf50043cc 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -14,21 +14,7 @@ limitations under the License. */ #include "paddle/parameter/Parameter.h" #include "PaddleAPI.h" - -struct ParameterPrivate { - std::shared_ptr sharedPtr; - paddle::Parameter* rawPtr; - - ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {} - - paddle::Parameter* getPtr() { - if (sharedPtr) { - return sharedPtr.get(); - } else { - return rawPtr; - } - } -}; +#include "PaddleAPIPrivate.h" Parameter::Parameter() : m(new ParameterPrivate()) {} diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp new file mode 100644 index 00000000000..af5b746a7cd --- /dev/null +++ b/paddle/api/ParameterUpdater.cpp @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "PaddleAPI.h" + +#include "PaddleAPIPrivate.h" +#include "paddle/trainer/ThreadParameterUpdater.h" + +ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {} + +ParameterUpdater *ParameterUpdater::createLocalUpdater( + OptimizationConfig *config) { + auto param = new ParameterUpdater(); + param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig())); + return param; +} + +ParameterUpdater::~ParameterUpdater() { delete m; } + +void ParameterUpdater::init(const GradientMachine &gm) { + m->updater->init(gm.m->machine->getParameters()); +} + +void ParameterUpdater::startPass() { m->updater->startPass(); } -- GitLab