提交 eaba2e2e 编写于 作者: Y Yu Yang

Expose Evaluator API

上级 1e6c87bd
...@@ -62,14 +62,14 @@ def main(): ...@@ -62,14 +62,14 @@ def main():
train_data_generator = input_order_converter( train_data_generator = input_order_converter(
read_from_mnist(train_file)) read_from_mnist(train_file))
for batch_id, data_batch in enumerate( for batch_id, data_batch in enumerate(
generator_to_batch(train_data_generator, 256)): generator_to_batch(train_data_generator, 2048)):
trainRole = updater.startBatch(len(data_batch)) trainRole = updater.startBatch(len(data_batch))
def update_callback(param): def updater_callback(param):
updater.update(param) updater.update(param)
m.forwardBackward( m.forwardBackward(
converter(data_batch), outArgs, trainRole, update_callback) converter(data_batch), outArgs, trainRole, updater_callback)
cost_vec = outArgs.getSlotValue(0) cost_vec = outArgs.getSlotValue(0)
cost_vec = cost_vec.copyToNumpyMat() cost_vec = cost_vec.copyToNumpyMat()
......
set(API_SOURCES set(API_SOURCES
Arguments.cpp Arguments.cpp
ConfigParser.cpp ConfigParser.cpp
Evaluator.cpp
GradientMachine.cpp GradientMachine.cpp
Matrix.cpp Matrix.cpp
Parameter.cpp Parameter.cpp
...@@ -63,6 +64,15 @@ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/ ...@@ -63,6 +64,15 @@ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
add_custom_target(python_api_wheel ALL DEPENDS add_custom_target(python_api_wheel ALL DEPENDS
${PROJ_ROOT}/paddle/dist/.timestamp) ${PROJ_ROOT}/paddle/dist/.timestamp)
add_dependencies(python_api_wheel python_swig_sources
paddle_parameter
paddle_math
paddle_utils
paddle_gserver
paddle_pserver
paddle_trainer
paddle_api
paddle_cuda)
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(test) add_subdirectory(test)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
Evaluator::Evaluator() : m(new EvaluatorPrivate()) {}
Evaluator::~Evaluator() { delete m; }
void Evaluator::start() { m->rawPtr->start(); }
void Evaluator::finish() { m->rawPtr->finish(); }
std::string Evaluator::toString() {
std::ostringstream sout;
m->rawPtr->printStats(sout);
return sout.str();
}
...@@ -162,3 +162,13 @@ SequenceGenerator* GradientMachine::asSequenceGenerator( ...@@ -162,3 +162,13 @@ SequenceGenerator* GradientMachine::asSequenceGenerator(
r->setBeamSize(beam_size); r->setBeamSize(beam_size);
return r; return r;
} }
Evaluator* GradientMachine::makeEvaluator() {
auto ev = new Evaluator();
ev->m->rawPtr = m->machine->makeEvaluator();
return ev;
}
void GradientMachine::eval(Evaluator* evaluator) {
m->machine->eval(evaluator->m->rawPtr);
}
...@@ -97,6 +97,7 @@ namespace std { ...@@ -97,6 +97,7 @@ namespace std {
%rename(__setitem__) Vector::set; %rename(__setitem__) Vector::set;
%rename(__len__) Vector::getSize; %rename(__len__) Vector::getSize;
%rename(__call__) ParameterTraverseCallback::apply; %rename(__call__) ParameterTraverseCallback::apply;
%rename(__repr__) Evaluator::toString;
%apply (float* INPLACE_ARRAY2, int DIM1, int DIM2) { %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2) {
(float* data, int dim1, int dim2) (float* data, int dim1, int dim2)
...@@ -167,6 +168,7 @@ namespace std { ...@@ -167,6 +168,7 @@ namespace std {
%newobject GradientMachine::asSequenceGenerator; %newobject GradientMachine::asSequenceGenerator;
%newobject GradientMachine::getParameter; %newobject GradientMachine::getParameter;
%newobject GradientMachine::getLayerOutput; %newobject GradientMachine::getLayerOutput;
%newobject GradientMachine::makeEvaluator;
%newobject TrainerConfig::createFromTrainerConfigFile; %newobject TrainerConfig::createFromTrainerConfigFile;
%newobject TrainerConfig::getModelConfig; %newobject TrainerConfig::getModelConfig;
%newobject TrainerConfig::getOptimizationConfig; %newobject TrainerConfig::getOptimizationConfig;
......
...@@ -685,7 +685,7 @@ private: ...@@ -685,7 +685,7 @@ private:
}; };
class SequenceGenerator; class SequenceGenerator;
class Evaluator;
struct GradientMachinePrivate; struct GradientMachinePrivate;
class GradientMachine { class GradientMachine {
private: private:
...@@ -770,6 +770,10 @@ public: ...@@ -770,6 +770,10 @@ public:
size_t max_length = 100UL, size_t max_length = 100UL,
size_t beam_size = -1UL); size_t beam_size = -1UL);
Evaluator* makeEvaluator();
void eval(Evaluator* evaluator);
private: private:
GradientMachinePrivate* m; GradientMachinePrivate* m;
...@@ -809,6 +813,27 @@ private: ...@@ -809,6 +813,27 @@ private:
ParameterUpdaterPrivate* m; ParameterUpdaterPrivate* m;
}; };
struct EvaluatorPrivate;
class Evaluator {
private:
Evaluator();
DISABLE_COPY_AND_ASSIGN(Evaluator);
public:
~Evaluator();
void start();
void finish();
std::string toString();
private:
EvaluatorPrivate* m;
friend class GradientMachine;
};
struct TrainerPrivate; struct TrainerPrivate;
class Trainer { class Trainer {
private: private:
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h" #include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"
#include "paddle/parameter/ParameterUpdaterBase.h" #include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/trainer/TrainerConfigHelper.h"
struct GradientMachinePrivate { struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine; std::shared_ptr<paddle::GradientMachine> machine;
...@@ -88,3 +88,10 @@ struct ParameterPrivate { ...@@ -88,3 +88,10 @@ struct ParameterPrivate {
} }
} }
}; };
struct EvaluatorPrivate {
paddle::Evaluator* rawPtr;
EvaluatorPrivate() : rawPtr(nullptr) {}
~EvaluatorPrivate() { delete rawPtr; }
};
...@@ -29,7 +29,7 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( ...@@ -29,7 +29,7 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
ParameterUpdater::~ParameterUpdater() { delete m; } ParameterUpdater::~ParameterUpdater() { delete m; }
void ParameterUpdater::init(const GradientMachine &gm) { void ParameterUpdater::init(const GradientMachine &gm) {
m->updater->init(gm.m->machine->getParameters()); m->updater->init(gm.m->machine->getNonStaticParameters());
} }
void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::startPass() { m->updater->startPass(); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册