提交 fec6f809 编写于 作者: X xuwei06 提交者: wangyang59

Skeleton for Generative Adverserial Nets

上级 7105962f
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
mode = get_config_arg("mode", str, "generator")
assert mode in set(["generator",
"discriminator",
"generator_training",
"discriminator_training"])
is_generator_training = mode == "generator_training"
is_discriminator_training = mode == "discriminator_training"
is_generator = mode == "generator"
is_discriminator = mode == "discriminator"
print('mode=%s' % mode)
noise_dim = 10
sample_dim = 2
settings(
batch_size=100,
learning_rate=1e-2,
learning_method=AdamOptimizer()
)
def discriminator(sample):
"""
discriminator ouputs the probablity of a sample is from generator
or real data.
The output has two dimenstional: dimension 0 is the probablity
of the sample is from generator and dimension 1 is the probabblity
of the sample is from real data.
"""
param_attr = ParamAttr(is_static=is_generator_training)
bias_attr = ParamAttr(is_static=is_generator_training,
initial_mean=0,
initial_std=0)
return fc_layer(input=sample, name="dis_prob", size=2,
bias_attr=bias_attr,
param_attr=param_attr,
act=SoftmaxActivation())
def generator(noise):
"""
generator generates a sample given noise
"""
param_attr = ParamAttr(is_static=is_discriminator_training)
bias_attr = ParamAttr(is_static=is_discriminator_training,
initial_mean=0,
initial_std=0)
return fc_layer(input=noise,
name="gen_layer1",
size=sample_dim,
bias_attr=bias_attr,
param_attr=param_attr,
act=LinearActivation())
if is_generator_training:
noise = data_layer(name="noise", size=noise_dim)
sample = generator(noise)
if is_discriminator_training:
sample = data_layer(name="sample", size=sample_dim)
if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1)
prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(input=prob, label=label, name=mode+'_error')
outputs(cost)
if is_generator:
noise = data_layer(name="noise", size=noise_dim)
outputs(generator(noise))
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import random
import numpy
from paddle.trainer.config_parser import parse_config
from paddle.trainer.config_parser import logger
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b)
def copy_shared_parameters(src, dst):
src_params = [src.getParameter(i)
for i in xrange(src.getParameterSize())]
src_params = dict([(p.getName(), p) for p in src_params])
for i in xrange(dst.getParameterSize()):
dst_param = dst.getParameter(i)
src_param = src_params.get(dst_param.getName(), None)
if src_param is None:
continue
src_value = src_param.getBuf(api.PARAMETER_VALUE)
dst_value = dst_param.getBuf(api.PARAMETER_VALUE)
CHECK_EQ(len(src_value), len(dst_value))
dst_value.copyFrom(src_value)
dst_param.setValueUpdated()
def get_real_samples(batch_size, sample_dim):
return numpy.random.rand(batch_size, sample_dim).astype('float32')
def prepare_discriminator_data_batch(
generator_machine, batch_size, noise_dim, sample_dim):
gen_inputs = prepare_generator_data_batch(batch_size / 2, noise_dim)
gen_inputs.resize(1)
gen_outputs = api.Arguments.createArguments(0)
generator_machine.forward(gen_inputs, gen_outputs, api.PASS_TEST)
fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
real_samples = get_real_samples(batch_size / 2, sample_dim)
all_samples = numpy.concatenate((fake_samples, real_samples), 0)
all_labels = numpy.concatenate(
(numpy.zeros(batch_size / 2, dtype='int32'),
numpy.ones(batch_size / 2, dtype='int32')), 0)
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createCpuDenseFromNumpy(all_samples))
inputs.setSlotIds(1, api.IVector.createCpuVectorFromNumpy(all_labels))
return inputs
def prepare_generator_data_batch(batch_size, dim):
noise = numpy.random.normal(size=(batch_size, dim)).astype('float32')
label = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2)
inputs.setSlotValue(0, api.Matrix.createCpuDenseFromNumpy(noise))
inputs.setSlotIds(1, api.IVector.createCpuVectorFromNumpy(label))
return inputs
def find(iterable, cond):
for item in iterable:
if cond(item):
return item
return None
def get_layer_size(model_conf, layer_name):
layer_conf = find(model_conf.layers, lambda x: x.name == layer_name)
assert layer_conf is not None, "Cannot find '%s' layer" % layer_name
return layer_conf.size
def main():
api.initPaddle('--use_gpu=0', '--dot_period=100', '--log_period=10000')
gen_conf = parse_config("gan_conf.py", "mode=generator_training")
dis_conf = parse_config("gan_conf.py", "mode=discriminator_training")
generator_conf = parse_config("gan_conf.py", "mode=generator")
batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise")
sample_dim = get_layer_size(dis_conf.model_config, "sample")
# this create a gradient machine for discriminator
dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config)
gen_training_machine = api.GradientMachine.createFromConfigProto(
gen_conf.model_config)
# generator_machine is used to generate data only, which is used for
# training discrinator
logger.info(str(generator_conf.model_config))
generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config)
dis_trainer = api.Trainer.create(
dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(
gen_conf, gen_training_machine)
dis_trainer.startTrain()
gen_trainer.startTrain()
for train_pass in xrange(10):
dis_trainer.startTrainPass()
gen_trainer.startTrainPass()
for i in xrange(100000):
copy_shared_parameters(gen_training_machine, generator_machine)
copy_shared_parameters(gen_training_machine, dis_training_machine)
data_batch = prepare_discriminator_data_batch(
generator_machine, batch_size, noise_dim, sample_dim)
dis_trainer.trainOneDataBatch(batch_size, data_batch)
copy_shared_parameters(dis_training_machine, gen_training_machine)
data_batch = prepare_generator_data_batch(
batch_size, noise_dim)
gen_trainer.trainOneDataBatch(batch_size, data_batch)
dis_trainer.finishTrainPass()
gen_trainer.finishTrainPass()
dis_trainer.finishTrain()
gen_trainer.finishTrain()
if __name__ == '__main__':
main()
......@@ -27,11 +27,6 @@ Arguments* Arguments::createArguments(size_t slotNum) {
void Arguments::resize(size_t slotNum) { m->outputs.resize(slotNum); }
Matrix* Arguments::getSlotValue(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return Matrix::createByPaddleMatrixPtr(&a.value);
}
Arguments::Arguments() : m(new ArgumentsPrivate()) {}
Arguments::~Arguments() { delete m; }
......@@ -43,6 +38,16 @@ Arguments* Arguments::createByPaddleArgumentVector(void* ptr) {
return args;
}
Matrix* Arguments::getSlotValue(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return Matrix::createByPaddleMatrixPtr(&a.value);
}
Matrix* Arguments::getSlotGrad(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return Matrix::createByPaddleMatrixPtr(&a.grad);
}
IVector* Arguments::getSlotIds(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return IVector::createByPaddleVectorPtr(&a.ids);
......@@ -58,6 +63,11 @@ void Arguments::setSlotValue(size_t idx, Matrix* mat) throw(RangeError) {
a.value = m->cast<paddle::Matrix>(mat->getSharedPtr());
}
void Arguments::setSlotGrad(size_t idx, Matrix* mat) throw(RangeError) {
auto& a = m->getArg(idx);
a.grad = m->cast<paddle::Matrix>(mat->getSharedPtr());
}
void Arguments::setSlotIn(size_t idx, Matrix* mat) throw(RangeError) {
auto& a = m->getArg(idx);
a.in = m->cast<paddle::Matrix>(mat->getSharedPtr());
......
......@@ -156,12 +156,15 @@ public:
* @param dim1 dimension of data.
* @param dim2 dimension of data.
* @param copy true if copy into a new matrix, false will create
* matrix inplace.
* matrix inplace. copy = false should be used with extreme
* care because Matrix will share the memory with the given
* numpy array. If the numpy array object is no longer valid,
* the memory space will not be usable.
*/
static Matrix* createCpuDenseFromNumpy(float* data,
int dim1,
int dim2,
bool copy = false);
bool copy = true);
/// Create Gpu Dense Matrix from numpy matrix, dtype=float32
static Matrix* createGpuDenseFromNumpy(float* data, int dim1, int dim2);
......@@ -271,11 +274,18 @@ public:
*/
static Vector* createCpuVectorFromNumpy(float* data,
int dim,
bool copy = false);
bool copy = true);
/// Create Gpu Vector from numpy array, which dtype=float32
static Vector* createGpuVectorFromNumpy(float* data, int dim);
/**
* copy from another vector
* throw(RangeError) if size of src vector is different from size of this
* vector
*/
void copyFrom(Vector* src) throw(RangeError);
/// Cast to numpy array inplace.
void toNumpyArrayInplace(float** view_data, int* dim1) throw(UnsupportError);
......@@ -339,7 +349,7 @@ public:
*/
static IVector* createCpuVectorFromNumpy(int* data,
int dim,
bool copy = false);
bool copy = true);
/**
* Create Gpu IVector from numpy array, which dtype=int32
*/
......@@ -418,6 +428,7 @@ public:
* the param idx is the slot id
*/
Matrix* getSlotValue(size_t idx) const throw(RangeError);
Matrix* getSlotGrad(size_t idx) const throw(RangeError);
IVector* getSlotIds(size_t idx) const throw(RangeError);
Matrix* getSlotIn(size_t idx) const throw(RangeError);
IVector* getSlotSequenceStartPositions(size_t idx) const throw(RangeError);
......@@ -434,6 +445,7 @@ public:
* The other param is the input Matrix or vector.
*/
void setSlotValue(size_t idx, Matrix* mat) throw(RangeError);
void setSlotGrad(size_t idx, Matrix* mat) throw(RangeError);
void setSlotIn(size_t idx, Matrix* mat) throw(RangeError);
void setSlotIds(size_t idx, IVector* vec) throw(RangeError);
void setSlotSequenceStartPositions(size_t idx,
......@@ -535,6 +547,7 @@ public:
size_t getID() const;
ParameterConfig* getConfig();
void setValueUpdated();
private:
static Parameter* createFromRawPtr(void* ptr);
......
......@@ -68,3 +68,5 @@ ParameterConfig* Parameter::getConfig() {
}
size_t Parameter::getID() const { return m->getPtr()->getID(); }
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
......@@ -281,6 +281,13 @@ FloatArray Vector::getData() const {
}
}
void Vector::copyFrom(Vector* src) throw(RangeError) {
if (src->m->vec->getSize() != m->vec->getSize()) {
throw RangeError();
}
m->vec->copyFrom(*src->m->vec);
}
bool Vector::isGpu() const {
return std::dynamic_pointer_cast<paddle::GpuVector>(m->vec) != nullptr;
}
......
......@@ -43,7 +43,7 @@ class TestIVector(unittest.TestCase):
def test_cpu_numpy(self):
vec = np.array([1, 3, 4, 65, 78, 1, 4], dtype="int32")
iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec)
iv = swig_paddle.IVector.createCpuVectorFromNumpy(vec, copy=False)
self.assertEqual(vec.shape[0], int(iv.__len__()))
vec[4] = 832
for i in xrange(len(iv)):
......@@ -107,7 +107,7 @@ class TestVector(unittest.TestCase):
def testCpuNumpy(self):
numpy_arr = np.array([1.2, 2.3, 3.4, 4.5], dtype="float32")
vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr)
vec = swig_paddle.Vector.createCpuVectorFromNumpy(numpy_arr, copy=False)
assert isinstance(vec, swig_paddle.Vector)
numpy_arr[0] = 0.1
for n, v in zip(numpy_arr, vec):
......@@ -152,4 +152,4 @@ if __name__ == '__main__':
unittest.TextTestRunner().run(suite)
if swig_paddle.isGpuVersion():
swig_paddle.setUseGpu(True)
unittest.main()
\ No newline at end of file
unittest.main()
......@@ -24,7 +24,9 @@ def doubleEqual(a, b):
def __readFromFile():
for i in xrange(10002):
yield np.random.rand(784), random.randint(0, 9)
label = np.random.randint(0, 9)
sample = np.random.rand(784) + 0.1 * label
yield sample, label
def loadMNISTTrainData(batch_size=100):
......
......@@ -559,10 +559,10 @@ def __monkey_patch_trainer__():
def monkeypatches():
patches = [
__monkeypatch_init_paddle__, __monkeypatch_gradient_machine__,
__monkey_patch_protobuf_objects__, __monkey_patch_parameter__,
__monkey_patch_trainer__
]
patches = [__monkeypatch_init_paddle__,
__monkeypatch_gradient_machine__,
__monkey_patch_protobuf_objects__,
__monkey_patch_parameter__,
__monkey_patch_trainer__]
for patch in patches:
patch()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册