diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 425c5f897a9c254bdae2aa1a6e91a4ce7a69874e..52cc13c5a3eaeece68d4198ee7ebd41d572f8b11 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -62,14 +62,14 @@ def main(): train_data_generator = input_order_converter( read_from_mnist(train_file)) for batch_id, data_batch in enumerate( - generator_to_batch(train_data_generator, 256)): + generator_to_batch(train_data_generator, 2048)): trainRole = updater.startBatch(len(data_batch)) - def update_callback(param): + def updater_callback(param): updater.update(param) m.forwardBackward( - converter(data_batch), outArgs, trainRole, update_callback) + converter(data_batch), outArgs, trainRole, updater_callback) cost_vec = outArgs.getSlotValue(0) cost_vec = cost_vec.copyToNumpyMat() diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index 39fe43556595ccfccb6190314ff43d323dc5c2e9..a7f17e186bf6b452628a24ab514c8e9aa2658e9d 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -1,6 +1,7 @@ set(API_SOURCES Arguments.cpp ConfigParser.cpp + Evaluator.cpp GradientMachine.cpp Matrix.cpp Parameter.cpp @@ -63,6 +64,15 @@ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/ add_custom_target(python_api_wheel ALL DEPENDS ${PROJ_ROOT}/paddle/dist/.timestamp) +add_dependencies(python_api_wheel python_swig_sources + paddle_parameter + paddle_math + paddle_utils + paddle_gserver + paddle_pserver + paddle_trainer + paddle_api + paddle_cuda) if(WITH_TESTING) add_subdirectory(test) diff --git a/paddle/api/Evaluator.cpp b/paddle/api/Evaluator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c30e09876397e37ef9ed4ec3200d1aa372ceb609 --- /dev/null +++ b/paddle/api/Evaluator.cpp @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include "PaddleAPI.h" +#include "PaddleAPIPrivate.h" + +Evaluator::Evaluator() : m(new EvaluatorPrivate()) {} +Evaluator::~Evaluator() { delete m; } + +void Evaluator::start() { m->rawPtr->start(); } + +void Evaluator::finish() { m->rawPtr->finish(); } + +std::string Evaluator::toString() { + std::ostringstream sout; + m->rawPtr->printStats(sout); + return sout.str(); +} diff --git a/paddle/api/GradientMachine.cpp b/paddle/api/GradientMachine.cpp index 2cece2109795a986966d2decfdde27b2759e51cc..0d1e17529611d11136914cb810b0633e0afccedf 100644 --- a/paddle/api/GradientMachine.cpp +++ b/paddle/api/GradientMachine.cpp @@ -162,3 +162,13 @@ SequenceGenerator* GradientMachine::asSequenceGenerator( r->setBeamSize(beam_size); return r; } + +Evaluator* GradientMachine::makeEvaluator() { + auto ev = new Evaluator(); + ev->m->rawPtr = m->machine->makeEvaluator(); + return ev; +} + +void GradientMachine::eval(Evaluator* evaluator) { + m->machine->eval(evaluator->m->rawPtr); +} diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index b0fa8beb166b3438d2b6cbf7afd46791979f41bb..7a110a90b84fcbbabd32639a97977322c2aecc2a 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -97,6 +97,7 @@ namespace std { %rename(__setitem__) Vector::set; %rename(__len__) Vector::getSize; %rename(__call__) ParameterTraverseCallback::apply; +%rename(__repr__) Evaluator::toString; %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2) { (float* data, int dim1, int dim2) @@ -167,6 +168,7 @@ namespace std { %newobject GradientMachine::asSequenceGenerator; %newobject GradientMachine::getParameter; %newobject GradientMachine::getLayerOutput; +%newobject GradientMachine::makeEvaluator; %newobject TrainerConfig::createFromTrainerConfigFile; %newobject TrainerConfig::getModelConfig; %newobject TrainerConfig::getOptimizationConfig; diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index cc49e6a09d5dee41ff47606025fcc492559aa958..413c38514646211befc18a83a2d7ce70644b5183 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -685,7 +685,7 @@ private: }; class SequenceGenerator; - +class Evaluator; struct GradientMachinePrivate; class GradientMachine { private: @@ -770,6 +770,10 @@ public: size_t max_length = 100UL, size_t beam_size = -1UL); + Evaluator* makeEvaluator(); + + void eval(Evaluator* evaluator); + private: GradientMachinePrivate* m; @@ -809,6 +813,27 @@ private: ParameterUpdaterPrivate* m; }; +struct EvaluatorPrivate; +class Evaluator { +private: + Evaluator(); + DISABLE_COPY_AND_ASSIGN(Evaluator); + +public: + ~Evaluator(); + + void start(); + + void finish(); + + std::string toString(); + +private: + EvaluatorPrivate* m; + + friend class GradientMachine; +}; + struct TrainerPrivate; class Trainer { private: diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/api/PaddleAPIPrivate.h index 905668a62f24fbd8db4a8833d92df2fe43b6d0c1..f41352bfec7c3333bde9509957aba8c5f373b9f2 100644 --- a/paddle/api/PaddleAPIPrivate.h +++ b/paddle/api/PaddleAPIPrivate.h @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once #include #include "PaddleAPI.h" +#include "paddle/gserver/evaluators/Evaluator.h" #include "paddle/gserver/gradientmachines/GradientMachine.h" -#include "paddle/trainer/TrainerConfigHelper.h" - #include "paddle/parameter/ParameterUpdaterBase.h" +#include "paddle/trainer/TrainerConfigHelper.h" struct GradientMachinePrivate { std::shared_ptr machine; @@ -88,3 +88,10 @@ struct ParameterPrivate { } } }; + +struct EvaluatorPrivate { + paddle::Evaluator* rawPtr; + + EvaluatorPrivate() : rawPtr(nullptr) {} + ~EvaluatorPrivate() { delete rawPtr; } +}; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index fba47620249dbc7543678b3e7e969a21ff32647a..91c839276280804bc9decc87c245728e0893de51 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -29,7 +29,7 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ParameterUpdater::~ParameterUpdater() { delete m; } void ParameterUpdater::init(const GradientMachine &gm) { - m->updater->init(gm.m->machine->getParameters()); + m->updater->init(gm.m->machine->getNonStaticParameters()); } void ParameterUpdater::startPass() { m->updater->startPass(); }