提交 65e957ca 编写于 作者: Y Yu Yang

Merge branch 'feature/mnist_train_api' of github.com:reyoung/Paddle into feature/mnist_train_api

...@@ -19,10 +19,9 @@ def init_parameter(network): ...@@ -19,10 +19,9 @@ def init_parameter(network):
assert isinstance(network, api.GradientMachine) assert isinstance(network, api.GradientMachine)
for each_param in network.getParameters(): for each_param in network.getParameters():
assert isinstance(each_param, api.Parameter) assert isinstance(each_param, api.Parameter)
array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace() array_size = len(each_param)
assert isinstance(array, np.ndarray) array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
for i in xrange(len(array)): each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)
array[i] = np.random.uniform(-1.0, 1.0)
def generator_to_batch(generator, batch_size): def generator_to_batch(generator, batch_size):
...@@ -175,7 +174,7 @@ def main(): ...@@ -175,7 +174,7 @@ def main():
for each_param in params: for each_param in params:
assert isinstance(each_param, api.Parameter) assert isinstance(each_param, api.Parameter)
value = each_param.getBuf(api.PARAMETER_VALUE) value = each_param.getBuf(api.PARAMETER_VALUE)
value = value.toNumpyArrayInplace() value = value.copyToNumpyArray()
# Here, we could save parameter to every where you want # Here, we could save parameter to every where you want
print each_param.getName(), value print each_param.getName(), value
......
...@@ -96,6 +96,7 @@ namespace std { ...@@ -96,6 +96,7 @@ namespace std {
%rename(__getitem__) Vector::get; %rename(__getitem__) Vector::get;
%rename(__setitem__) Vector::set; %rename(__setitem__) Vector::set;
%rename(__len__) Vector::getSize; %rename(__len__) Vector::getSize;
%rename(__len__) Parameter::getSize;
%rename(__call__) ParameterTraverseCallback::apply; %rename(__call__) ParameterTraverseCallback::apply;
%rename(__repr__) Evaluator::toString; %rename(__repr__) Evaluator::toString;
......
...@@ -550,6 +550,8 @@ public: ...@@ -550,6 +550,8 @@ public:
ParameterConfig* getConfig(); ParameterConfig* getConfig();
void setValueUpdated(); void setValueUpdated();
size_t getSize() const;
private: private:
static Parameter* createFromRawPtr(void* ptr); static Parameter* createFromRawPtr(void* ptr);
static Parameter* createFromSharedPtr(void* ptr); static Parameter* createFromSharedPtr(void* ptr);
......
...@@ -56,3 +56,5 @@ ParameterConfig* Parameter::getConfig() { ...@@ -56,3 +56,5 @@ ParameterConfig* Parameter::getConfig() {
size_t Parameter::getID() const { return m->getPtr()->getID(); } size_t Parameter::getID() const { return m->getPtr()->getID(); }
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
size_t Parameter::getSize() const { return m->getPtr()->getSize(); }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册