diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index e5a9075c8ecc0b640bc01e39468b0e4946246998..129922c30b48e8e58a30cfdee60feb6478c5116d 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -19,10 +19,9 @@ def init_parameter(network): assert isinstance(network, api.GradientMachine) for each_param in network.getParameters(): assert isinstance(each_param, api.Parameter) - array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace() - assert isinstance(array, np.ndarray) - for i in xrange(len(array)): - array[i] = np.random.uniform(-1.0, 1.0) + array_size = len(each_param) + array = np.random.uniform(-1.0, 1.0, array_size).astype('float32') + each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array) def generator_to_batch(generator, batch_size): @@ -175,7 +174,7 @@ def main(): for each_param in params: assert isinstance(each_param, api.Parameter) value = each_param.getBuf(api.PARAMETER_VALUE) - value = value.toNumpyArrayInplace() + value = value.copyToNumpyArray() # Here, we could save parameter to every where you want print each_param.getName(), value diff --git a/paddle/api/Paddle.swig b/paddle/api/Paddle.swig index 7a110a90b84fcbbabd32639a97977322c2aecc2a..3365927f9b59936244230bed439808fa7ead2c61 100644 --- a/paddle/api/Paddle.swig +++ b/paddle/api/Paddle.swig @@ -96,6 +96,7 @@ namespace std { %rename(__getitem__) Vector::get; %rename(__setitem__) Vector::set; %rename(__len__) Vector::getSize; +%rename(__len__) Parameter::getSize; %rename(__call__) ParameterTraverseCallback::apply; %rename(__repr__) Evaluator::toString; diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index d94fd1e52ed0367d9ee5276b1a2480260d93bce1..d4b057e8a19894118cf4bbf2e067d8b3f6ad4e6b 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -550,6 +550,8 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + size_t getSize() const; + private: static Parameter* createFromRawPtr(void* ptr); static Parameter* createFromSharedPtr(void* ptr); diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index 41cf50043cc2b076dad49b9e772252b9243f39d6..ddc00d8d1af4c58d7e2233423bea916408bee92b 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -56,3 +56,5 @@ ParameterConfig* Parameter::getConfig() { size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } + +size_t Parameter::getSize() const { return m->getPtr()->getSize(); }