From 59009ba72d54cc35717dbd80d73500f11fbb7852 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 22 Dec 2016 14:51:51 +0800 Subject: [PATCH] Always use copy method for numpy. * Make this demo support GPU --- demo/mnist/.gitignore | 1 + demo/mnist/api_train.py | 9 ++++----- paddle/api/Paddle.swig | 1 + paddle/api/PaddleAPI.h | 2 ++ paddle/api/Parameter.cpp | 2 ++ 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 810910fd5ca..8bd9837523c 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -4,3 +4,4 @@ mnist_vgg_model plot.png train.log *pyc +.ipynb_checkpoints diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index ce75d79bebe..7e653246a31 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -19,10 +19,9 @@ def init_parameter(network): assert isinstance(network, api.GradientMachine) for each_param in network.getParameters(): assert isinstance(each_param, api.Parameter) - array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace() - assert isinstance(array, np.ndarray) - for i in xrange(len(array)): - array[i] = np.random.uniform(-1.0, 1.0) + array_size = len(each_param) + array = np.random.uniform(-1.0, 1.0, array_size).astype('float32') + each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array) def generator_to_batch(generator, batch_size): @@ -175,7 +174,7 @@ def main(): for each_param in params: assert isinstance(each_param, api.Parameter) value = each_param.getBuf(api.PARAMETER_VALUE) - value = value.toNumpyArrayInplace() + value = value.copyToNumpyArray() # Here, we could save parameter to every where you want print each_param.getName(), value diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 7a110a90b84..3365927f9b5 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -96,6 +96,7 @@ namespace std { %rename(__getitem__) Vector::get; %rename(__setitem__) Vector::set; %rename(__len__) Vector::getSize; +%rename(__len__) Parameter::getSize; %rename(__call__) ParameterTraverseCallback::apply; %rename(__repr__) Evaluator::toString; diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index d94fd1e52ed..d4b057e8a19 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -550,6 +550,8 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + size_t getSize() const; + private: static Parameter* createFromRawPtr(void* ptr); static Parameter* createFromSharedPtr(void* ptr); diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index 41cf50043cc..ddc00d8d1af 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -56,3 +56,5 @@ ParameterConfig* Parameter::getConfig() { size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } + +size_t Parameter::getSize() const { return m->getPtr()->getSize(); } -- GitLab