Unverified commit 340a104c, authored by Qiyang Min, committed by GitHub

Merge pull request #12658 from velconia/port_pybind11

Port pybind11 and python code to support py3 CI test
......@@ -202,6 +202,52 @@ std::vector<std::string> OpDesc::AttrNames() const {
}
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 maps an empty list in Python to the
// std::vector<int> type in C++, so we have to correct the attr's type
// here when we encounter this case
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
boost::get<std::vector<int>>(v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
const proto::OpProto::Attr &attr = GetProtoAttr(name);
switch (attr.type()) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BOOLEANS";
this->attrs_[name] = std::vector<bool>();
break;
}
case proto::AttrType::INTS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to INTS";
this->attrs_[name] = std::vector<int>();
break;
}
case proto::AttrType::FLOATS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOATS";
this->attrs_[name] = std::vector<float>();
break;
}
case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS";
this->attrs_[name] = std::vector<std::string>();
break;
}
case proto::AttrType::BLOCKS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BLOCKS";
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
return;
}
default:
PADDLE_THROW("Wrong attr type %d", attr.type());
}
need_update_ = true;
return;
}
this->attrs_[name] = v;
need_update_ = true;
}
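A hedged Python-side sketch of the ambiguity this guards against: pybind11 infers std::vector<int> for an empty Python list, regardless of the attribute type declared in the op proto. The op type and attribute name below are hypothetical; a real op registered with a STRINGS attribute is assumed.

import paddle.fluid.core as core

op_desc = core.OpDesc()
op_desc.set_type('some_op')        # hypothetical op with a STRINGS attr 'names'
op_desc.set_attr('names', ['a'])   # pybind11 infers std::vector<std::string>
op_desc.set_attr('names', [])      # [] arrives typed as INTS; the SetAttr fix
                                   # above consults the op proto to restore the
                                   # declared attribute type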
......@@ -229,6 +275,19 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second;
}
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
const std::string &name) const {
const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return attr;
}
}
PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
......
......@@ -81,6 +81,8 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const;
const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const;
int GetBlockAttrId(const std::string &name) const;
......
......@@ -205,12 +205,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc
.def("name",
[](pd::VarDesc &self) {
pybind11::bytes name = self.Name();
return name;
},
pybind11::return_value_policy::reference)
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape)
.def("set_shapes", &pd::VarDesc::SetShapes)
......
......@@ -54,6 +54,8 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#endif
#include "pybind11/stl.h"
// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
......
......@@ -24,4 +24,5 @@ except ImportError:
import paddle.reader
import paddle.dataset
import paddle.batch
import paddle.compat
batch = batch.batch
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import math
__all__ = [
'long_type',
'to_text',
'to_bytes',
'round',
'floor_division',
'get_exception_message',
]
if six.PY2:
int_type = int
long_type = long
else:
int_type = int
long_type = int
# str and bytes related functions
def to_text(obj, encoding='utf-8', inplace=False):
"""
All strings in PaddlePaddle should be represented as literal strings.
This function converts the object to a literal string without any encoding.
In particular, if the object is a list or set container, it iterates over
all items in the object and converts them to literal strings.
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding
Args:
obj(unicode|str|bytes|list|set) : The object to be decoded.
encoding(str) : The encoding format to decode a string
inplace(bool) : Whether to modify the original object in place or to create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_text(obj[i], encoding)
return obj
else:
return [_to_text(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_text(item, encoding))
return obj
else:
return set([_to_text(item, encoding) for item in obj])
else:
return _to_text(obj, encoding)
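A minimal usage sketch of to_text (return values shown assume Python 3):

import paddle.compat as cpt

cpt.to_text(b'hello')            # -> 'hello'
cpt.to_text([b'a', b'b'])        # -> ['a', 'b'], a new list
data = [b'a', b'b']
cpt.to_text(data, inplace=True)  # decodes the items in place and returns data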
def _to_text(obj, encoding):
"""
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding,
or return the unicode string of the object directly
Args:
obj(unicode|str|bytes) : The object to be decoded.
encoding(str) : The encoding format
Returns:
decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, six.binary_type):
return obj.decode(encoding)
elif isinstance(obj, six.text_type):
return obj
else:
return six.u(obj)
def to_bytes(obj, encoding='utf-8', inplace=False):
"""
All strings in PaddlePaddle should be represented as literal strings.
This function converts the object to bytes with specific encoding.
In particular, if the object is a list or set container, it iterates over
all items in the object and converts them to bytes.
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or return the 8-bit string of the object directly
Args:
obj(unicode|str|bytes|list|set) : The object to be encoded.
encoding(str) : The encoding format to encode a string
inplace(bool) : Whether to modify the original object in place or to create a new one
Returns:
Encoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_bytes(obj[i], encoding)
return obj
else:
return [_to_bytes(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_bytes(item, encoding))
return obj
else:
return set([_to_bytes(item, encoding) for item in obj])
else:
return _to_bytes(obj, encoding)
def _to_bytes(obj, encoding):
"""
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or return the 8-bit string of the object directly
Args:
obj(unicode|str|bytes) : The object to be encoded.
encoding(str) : The encoding format
Returns:
encoded result of obj
"""
if obj is None:
return obj
assert encoding is not None
if isinstance(obj, six.text_type):
return obj.encode(encoding)
elif isinstance(obj, six.binary_type):
return obj
else:
return six.b(obj)
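And the mirror-image sketch for to_bytes (values shown assume Python 3):

import paddle.compat as cpt

cpt.to_bytes(u'hello')             # -> b'hello'
cpt.to_bytes({'a', 'b'})           # -> {b'a', b'b'}, a new set
cpt.to_bytes(['a'], inplace=True)  # encodes the items in place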
# math related functions
def round(x, d=0):
"""
Compatible round that behaves the same as round in Python 3.
Args:
x(float) : The number to be rounded.
d(int) : The number of decimal places to round to. Default is 0.
Returns:
Rounded result of x
"""
if six.PY3:
# The official workaround for round in Python 3 does not behave as expected,
# so we implement it according to this answer: https://www.techforgeek.info/round_python.html
if x > 0.0:
p = 10**d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p
elif x < 0.0:
p = 10**d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else:
return math.copysign(0.0, x)
else:
import __builtin__
return __builtin__.round(x, d)
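Python 3's built-in round uses banker's rounding (round half to even), while this helper always rounds halves away from zero, matching Python 2. A quick comparison, assuming Python 3:

import paddle.compat as cpt

round(2.5)           # built-in: 2 (round half to even)
cpt.round(2.5)       # -> 3.0 (round half away from zero, as in Python 2)
cpt.round(-2.5)      # -> -3.0
cpt.round(0.125, 2)  # -> 0.13, where the built-in gives 0.12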
def floor_division(x, y):
"""
Compatible division that behaves the same in Python 3 and Python 2:
the result is the int value of floor(x / y) in Python 3, which matches
the value of (x / y) in Python 2.
Args:
x(int|float) : The dividend.
y(int|float) : The divisor.
Returns:
Division result of x // y
"""
return x // y
# exception related functions
def get_exception_message(exc):
"""
Get the error message of a specific exception
Args:
exc(Exception) : The exception to get the error message from.
Returns:
The error message of exc
"""
assert exc is not None
if six.PY2:
return exc.message
else:
return str(exc)
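A short usage sketch; Exception.message was removed in Python 3, which is the reason this helper exists:

import paddle.compat as cpt

try:
    raise ValueError('bad value')
except ValueError as e:
    msg = cpt.get_exception_message(e)  # 'bad value' on Python 2 and 3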
......@@ -32,7 +32,7 @@ import itertools
import numpy
import paddle.dataset.common
import tarfile
from six.moves import zip
import six
from six.moves import cPickle as pickle
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
......@@ -46,10 +46,11 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
data = batch[six.b('data')]
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in zip(data, labels):
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
......@@ -59,7 +60,11 @@ def reader_creator(filename, sub_name, cycle=False):
while True:
for name in names:
batch = pickle.load(f.extractfile(name))
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
if not cycle:
......
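Pickles written by Python 2 carry str keys that load as bytes keys under Python 3 when encoding='bytes' is used, which is why read_batch indexes the batch dict with six.b(...). A self-contained sketch of the pattern (the data is made up):

import pickle
import six

blob = pickle.dumps({b'data': [1, 2], b'labels': [0, 1]})
if six.PY2:
    batch = pickle.loads(blob)
else:
    batch = pickle.loads(blob, encoding='bytes')
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels == [0, 1]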
......@@ -85,10 +85,10 @@ def download(url, module_name, md5sum, save_name=None):
total_length = r.headers.get('content-length')
if total_length is None:
with open(filename, 'w') as f:
with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(filename, 'w') as f:
with open(filename, 'wb') as f:
dl = 0
total_length = int(total_length)
for data in r.iter_content(chunk_size=4096):
......
......@@ -24,11 +24,12 @@ import tarfile
import gzip
import itertools
import paddle.dataset.common
from six.moves import zip
import paddle.compat as cpt
from six.moves import zip, range
__all__ = ['test', 'get_dict', 'get_embedding', 'convert']
DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc'
WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
......@@ -89,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name):
labels = []
one_seg = []
for word, label in zip(words_file, props_file):
word = word.strip()
label = label.strip().split()
word = cpt.to_text(word.strip())
label = cpt.to_text(label.strip().split())
if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])):
......
......@@ -116,8 +116,8 @@ def reader_creator(data_file,
for file in open(file_list):
file = file.strip()
batch = None
with open(file, 'r') as f:
batch = pickle.load(f)
with open(file, 'rb') as f:
batch = pickle.loads(f.read())
data = batch['data']
labels = batch['label']
for sample, label in zip(data, batch['label']):
......
......@@ -33,6 +33,11 @@ import numpy as np
try:
import cv2
except ImportError:
import sys
sys.stderr.write(
'''Warning with paddle image module: opencv-python could not be imported,
so the paddle image module will NOT work; please install opencv-python first.'''
)
cv2 = None
import os
import tarfile
......@@ -56,7 +61,7 @@ def batch_images_from_tar(data_file,
:type data_file: string
:param dataset_name: 'train','test' or 'valid'
:type dataset_name: string
:param img2label: a dic with image file name as key
:param img2label: a dic with image file name as key
and image's label as value
:type img2label: dic
:param num_per_batch: image number per batch file
......@@ -88,7 +93,7 @@ def batch_images_from_tar(data_file,
output['data'] = data
pickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
open('%s/batch_%d' % (out_path, file_id), 'wb'),
protocol=pickle.HIGHEST_PROTOCOL)
file_id += 1
data = []
......@@ -99,7 +104,7 @@ def batch_images_from_tar(data_file,
output['data'] = data
pickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
open('%s/batch_%d' % (out_path, file_id), 'wb'),
protocol=pickle.HIGHEST_PROTOCOL)
with open(meta_file, 'a') as meta:
......@@ -113,7 +118,7 @@ def load_image_bytes(bytes, is_color=True):
Load an color or gray image from bytes array.
Example usage:
.. code-block:: python
with open('cat.jpg') as f:
......@@ -126,6 +131,8 @@ def load_image_bytes(bytes, is_color=True):
load and return a gray image.
:type is_color: bool
"""
assert cv2 is not None
flag = 1 if is_color else 0
file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
img = cv2.imdecode(file_bytes, flag)
......@@ -137,7 +144,7 @@ def load_image(file, is_color=True):
Load an color or gray image from the file path.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
......@@ -149,6 +156,8 @@ def load_image(file, is_color=True):
load and return a gray image.
:type is_color: bool
"""
assert cv2 is not None
# cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3
......@@ -161,27 +170,29 @@ def load_image(file, is_color=True):
def resize_short(im, size):
"""
"""
Resize an image so that the length of shorter edge is size.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the shorter edge size of image after resizing.
:type size: int
"""
assert cv2 is not None
h, w = im.shape[:2]
h_new, w_new = size, size
if h > w:
h_new = size * h / w
h_new = size * h // w
else:
w_new = size * w / h
w_new = size * w // h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im
......@@ -193,17 +204,17 @@ def to_chw(im, order=(2, 0, 1)):
according the order (2,0,1).
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
im = to_chw(im)
:param im: the input image with HWC layout.
:type im: ndarray
:param order: the transposed order.
:type order: tuple|list
:type order: tuple|list
"""
assert len(im.shape) == len(order)
im = im.transpose(order)
......@@ -215,11 +226,11 @@ def center_crop(im, size, is_color=True):
Crop the center of image with size.
Example usage:
.. code-block:: python
im = center_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
......@@ -228,8 +239,8 @@ def center_crop(im, size, is_color=True):
:type is_color: bool
"""
h, w = im.shape[:2]
h_start = (h - size) / 2
w_start = (w - size) / 2
h_start = (h - size) // 2
w_start = (w - size) // 2
h_end, w_end = h_start + size, w_start + size
if is_color:
im = im[h_start:h_end, w_start:w_end, :]
......@@ -243,11 +254,11 @@ def random_crop(im, size, is_color=True):
Randomly crop input image with size.
Example usage:
.. code-block:: python
im = random_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
......@@ -272,11 +283,11 @@ def left_right_flip(im, is_color=True):
Return the flipped image.
Example usage:
.. code-block:: python
im = left_right_flip(im)
:param im: input image with HWC layout or HW layout for gray image
:type im: ndarray
:param is_color: whether input image is color or not
......@@ -299,7 +310,7 @@ def simple_transform(im,
resizing, cropping and flipping.
Example usage:
.. code-block:: python
im = simple_transform(im, 256, 224, True)
......@@ -314,7 +325,7 @@ def simple_transform(im,
:type is_train: bool
:param is_color: whether the image is color or not.
:type is_color: bool
:param mean: the mean values, which can be element-wise mean values or
:param mean: the mean values, which can be element-wise mean values or
mean values per channel.
:type mean: numpy array | list
"""
......@@ -332,7 +343,7 @@ def simple_transform(im,
im = im.astype('float32')
if mean is not None:
mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel
# mean value, may be one value per channel
if mean.ndim == 1 and is_color:
mean = mean[:, np.newaxis, np.newaxis]
elif mean.ndim == 1:
......@@ -357,7 +368,7 @@ def load_and_transform(filename,
for the transform operations.
Example usage:
.. code-block:: python
im = load_and_transform('cat.jpg', 256, 224, True)
......@@ -372,7 +383,7 @@ def load_and_transform(filename,
:type is_train: bool
:param is_color: whether the image is color or not.
:type is_color: bool
:param mean: the mean values, which can be element-wise mean values or
:param mean: the mean values, which can be element-wise mean values or
mean values per channel.
:type mean: numpy array | list
"""
......
......@@ -25,6 +25,7 @@ import collections
import tarfile
import re
import string
import six
__all__ = ['build_dict', 'train', 'test', 'convert']
......@@ -42,13 +43,14 @@ def tokenize(pattern):
# sequential access of member files, other than
# tarfile.extractfile, which does random access and might
# destroy hard disks.
tf = next(tarf)
tf = tarf.next()
while tf != None:
if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip("\n\r").translate(
None, string.punctuation).lower().split()
tf = next(tarf)
yield tarf.extractfile(tf).read().rstrip(six.b(
"\n\r")).translate(
None, six.b(string.punctuation)).lower().split()
tf = tarf.next()
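Under Python 3 the tar members are read as bytes, and bytes.translate(None, delete) deletes the given bytes, so both the stripped newlines and the punctuation table must be bytes themselves, hence the six.b(...) wrappers. A minimal sketch of the tokenization step:

import string
import six

line = six.b('Hello, world!\n\r')
tokens = line.rstrip(six.b('\n\r')).translate(
    None, six.b(string.punctuation)).lower().split()
assert tokens == [six.b('hello'), six.b('world')]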
def build_dict(pattern, cutoff):
......@@ -62,11 +64,11 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1
# Not sure if we should prune less-frequent words here.
word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff]
word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(list(zip(words, list(range(len(words))))))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words)
return word_idx
......
......@@ -14,13 +14,14 @@
"""
imikolov's simple dataset.
This module will download dataset from
This module will download dataset from
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators.
"""
import paddle.dataset.common
import collections
import tarfile
import six
__all__ = ['train', 'test', 'build_dict', 'convert']
......@@ -64,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq]
word_freq = [
x for x in six.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, list(range(len(words))))))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words)
return word_idx
......@@ -89,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
for i in six.moves.range(n, len(l) + 1):
yield tuple(l[i - n:i])
elif DataType.SEQ == data_type:
l = l.strip().split()
......
......@@ -21,6 +21,9 @@ import paddle.dataset.common
import subprocess
import numpy
import platform
import six
import tempfile
from six.moves import range
__all__ = ['train', 'test', 'convert']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
......@@ -45,23 +48,28 @@ def reader_creator(image_filename, label_filename, buffer_size):
# According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here.
m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
m.stdout.read(16) # skip some magic bytes
tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset')
m = subprocess.Popen(
[zcat_cmd, image_filename], stdout=tmp_image_file).communicate()
tmp_image_file.seek(16) # skip some magic bytes
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
l.stdout.read(8) # skip some magic bytes
# Python3 will not take stdout as file
tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset')
l = subprocess.Popen(
[zcat_cmd, label_filename], stdout=tmp_label_file).communicate()
tmp_label_file.seek(8) # skip some magic bytes
try: # reader could be break.
while True:
labels = numpy.fromfile(
l.stdout, 'ubyte', count=buffer_size).astype("int")
tmp_label_file, 'ubyte', count=buffer_size).astype("int")
if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF.
images = numpy.fromfile(
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
(buffer_size, 28 * 28)).astype('float32')
tmp_image_file, 'ubyte', count=buffer_size * 28 *
28).reshape((buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
......
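numpy.fromfile needs a real file object; under Python 3, Popen.stdout is a buffered pipe that it cannot consume, so the stream is first spooled into a temporary file. A minimal sketch of the pattern, assuming zcat is available and the filename is hypothetical:

import subprocess
import tempfile
import numpy

tmp = tempfile.TemporaryFile(prefix='paddle_dataset')
subprocess.Popen(['zcat', 'train-images-idx3-ubyte.gz'],
                 stdout=tmp).communicate()
tmp.seek(16)  # skip the magic number and header bytes
images = numpy.fromfile(tmp, 'ubyte', count=100 * 28 * 28)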
......@@ -27,6 +27,8 @@ import paddle.dataset.common
import re
import random
import functools
import six
import paddle.compat as cpt
__all__ = [
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
......@@ -112,6 +114,7 @@ def __initialize_meta_info__():
categories_set = set()
with package.open('ml-1m/movies.dat') as movie_file:
for i, line in enumerate(movie_file):
line = cpt.to_text(line, encoding='latin')
movie_id, title, categories = line.strip().split('::')
categories = categories.split('|')
for c in categories:
......@@ -136,6 +139,7 @@ def __initialize_meta_info__():
USER_INFO = dict()
with package.open('ml-1m/users.dat') as user_file:
for line in user_file:
line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::")
USER_INFO[int(uid)] = UserInfo(
index=uid, gender=gender, age=age, job_id=job)
......@@ -148,6 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
with zipfile.ZipFile(file=fn) as package:
with package.open('ml-1m/ratings.dat') as rating:
for line in rating:
line = cpt.to_text(line, encoding='latin')
if (rand.random() < test_ratio) == is_test:
uid, mov_id, rating, _ = line.strip().split("::")
uid = int(uid)
......@@ -187,7 +192,7 @@ def max_movie_id():
Get the maximum value of movie id.
"""
__initialize_meta_info__()
return reduce(__max_index_info__, list(MOVIE_INFO.values())).index
return six.moves.reduce(__max_index_info__, list(MOVIE_INFO.values())).index
def max_user_id():
......@@ -195,7 +200,7 @@ def max_user_id():
Get the maximum value of user id.
"""
__initialize_meta_info__()
return reduce(__max_index_info__, list(USER_INFO.values())).index
return six.moves.reduce(__max_index_info__, list(USER_INFO.values())).index
def __max_job_id_impl__(a, b):
......@@ -210,7 +215,8 @@ def max_job_id():
Get the maximum value of job id.
"""
__initialize_meta_info__()
return reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id
return six.moves.reduce(__max_job_id_impl__,
list(USER_INFO.values())).job_id
def movie_categories():
......
......@@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset.
"""
import six
import collections
from itertools import chain
......@@ -64,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = list(word_freq_dict.items())
words_sort_list = six.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index))
......
......@@ -16,6 +16,7 @@ import paddle.dataset.common
import unittest
import tempfile
import glob
from six.moves import range
class TestCommon(unittest.TestCase):
......
......@@ -22,6 +22,7 @@ parse training set and test set into paddle reader creators.
import os
import numpy as np
import six
import tempfile
import tarfile
import os
......@@ -70,11 +71,11 @@ def load_data(filename, feature_num=14, ratio=0.8):
return
data = np.fromfile(filename, sep=' ')
data = data.reshape(data.shape[0] / feature_num, feature_num)
data = data.reshape(data.shape[0] // feature_num, feature_num)
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
axis=0) / data.shape[0]
feature_range(maximums[:-1], minimums[:-1])
for i in range(feature_num - 1):
for i in six.moves.range(feature_num - 1):
data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
offset = int(data.shape[0] * ratio)
UCI_TRAIN_DATA = data[:offset]
......@@ -137,7 +138,7 @@ def predict_reader():
It returns just one tuple data to do inference.
:return: one tuple data
:rtype: tuple
:rtype: tuple
"""
global UCI_TEST_DATA
load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5))
......
......@@ -19,10 +19,12 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parse training set and test set into paddle reader creators.
"""
import six
import tarfile
import gzip
import paddle.dataset.common
import paddle.compat as cpt
__all__ = [
'train',
......@@ -53,7 +55,7 @@ def __read_to_dict(tar_file, dict_size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
out_dict[line.strip()] = line_count
out_dict[cpt.to_text(line.strip())] = line_count
else:
break
return out_dict
......@@ -84,7 +86,7 @@ def reader_creator(tar_file, file_name, dict_size):
]
for name in names:
for line in f.extractfile(name):
line_split = line.strip().split('\t')
line_split = line.strip().split(six.b('\t'))
if len(line_split) != 2:
continue
src_seq = line_split[0] # one source sequence
......@@ -153,8 +155,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in list(src_dict.items())}
trg_dict = {v: k for k, v in list(trg_dict.items())}
src_dict = {v: k for k, v in six.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
return src_dict, trg_dict
......
......@@ -29,11 +29,13 @@ Multi30K: Multilingual English-German Image Descriptions.
"""
import os
import six
import tarfile
import gzip
from collections import defaultdict
import paddle.dataset.common
import paddle.compat as cpt
__all__ = [
"train",
......@@ -60,7 +62,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
word_dict = defaultdict(int)
with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile("wmt16/train"):
line_split = line.strip().split("\t")
line_split = line.strip().split(six.b("\t"))
if len(line_split) != 2: continue
sen = line_split[0] if lang == "en" else line_split[1]
for w in sen.split():
......@@ -70,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate(
sorted(
iter(list(word_dict.items())),
key=lambda x: x[1],
six.iteritems(word_dict), key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0]))
......@@ -81,16 +82,16 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size):
len(open(dict_path, "rb").readlines()) != dict_size):
__build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {}
with open(dict_path, "r") as fdict:
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip()
word_dict[idx] = cpt.to_text(line.strip())
else:
word_dict[line.strip()] = idx
word_dict[cpt.to_text(line.strip())] = idx
return word_dict
......@@ -120,7 +121,7 @@ def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile(file_name):
line_split = line.strip().split("\t")
line_split = line.strip().split(six.b("\t"))
if len(line_split) != 2:
continue
src_words = line_split[src_col].split()
......
......@@ -17,6 +17,7 @@ from . import core
import collections
import copy
import six
from .. import compat as cpt
from . import unique_name
__all__ = ['append_backward']
......@@ -45,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for para, args in list(inputs.items()):
for para, args in six.iteritems(inputs):
op_desc.set_input(
para,
list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args)))
for para, args in list(outputs.items()):
for para, args in six.iteritems(outputs):
op_desc.set_output(
para,
list(
......@@ -63,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in list(attrs.items()):
for name, val in six.iteritems(attrs):
if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc)
else:
......@@ -75,10 +76,10 @@ def _infer_var_data_type_(grad_var_name, block):
"""
Infer the data type of given grad variable
"""
grad_var = block.desc.find_var(grad_var_name.encode("ascii"))
fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii"))
if block.desc.has_var_recursive(fwd_name):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var = block.desc.find_var(cpt.to_bytes(grad_var_name))
fwd_name = _strip_grad_suffix_(grad_var_name)
if block.desc.has_var_recursive(cpt.to_bytes(fwd_name)):
fwd_var = block.desc.find_var_recursive(cpt.to_bytes(fwd_name))
grad_var.set_dtype(fwd_var.dtype())
else:
grad_var.set_dtype(core.VarDesc.VarType.FP32)
......@@ -102,8 +103,10 @@ def _some_in_set_(cands, s):
"""
if len(cands) == 0:
return False
for c in cands:
if c in s:
literal_set = cpt.to_text(s)
literal_cands = cpt.to_text(cands)
for c in literal_cands:
if c in literal_set:
return True
return False
......@@ -114,9 +117,8 @@ def _strip_grad_suffix_(name):
e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y
"""
if isinstance(name, six.text_type):
name = name.encode()
pos = name.find(six.b(core.grad_var_suffix()))
name = cpt.to_text(name)
pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name
......@@ -125,9 +127,7 @@ def _append_grad_suffix_(name):
Append grad suffix to the given variable name
e.g. x ==> x@GRAD
"""
if isinstance(name, six.text_type):
name = name.encode()
return name + six.b(core.grad_var_suffix())
return cpt.to_text(name) + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs):
......@@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
for var_name, inputs in list(renamed_vars.items()):
for var_name, inputs in six.iteritems(renamed_vars):
if len(inputs) > 1:
pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
......@@ -243,7 +243,7 @@ from .proto import framework_pb2
def serialize_op_decs(op_desc):
protostr = op_desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
proto = framework_pb2.OpDesc.FromString(six.binary_type(protostr))
return proto.__str__()
......@@ -364,7 +364,7 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
......@@ -411,11 +411,10 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
new_vars = set()
# create new gradient variables
for grad_var_name in op_desc.output_arg_names():
grad_var_name = grad_var_name.encode("ascii")
if block.desc.has_var_recursive(
grad_var_name) or grad_var_name == core.empty_var_name():
if block.desc.has_var_recursive(cpt.to_bytes(
grad_var_name)) or grad_var_name == core.empty_var_name():
continue
block.desc.var(grad_var_name)
block.desc.var(cpt.to_bytes(grad_var_name))
new_vars.add(grad_var_name)
if grad_var_name not in grad_to_var:
continue
......@@ -445,7 +444,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name)
var_map[name] = new_name
for g, ng in list(var_map.items()):
for g, ng in six.iteritems(var_map):
if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g)
......@@ -595,11 +594,12 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
parameters = parameter_list
else:
params = program.global_block().all_parameters()
program.global_block().iter_parameters()
parameters = [param.name for param in params]
params_and_grads = []
for param in parameters:
if param not in grad_info_map:
if cpt.to_text(param) not in grad_info_map:
continue
grad_info = grad_info_map[param]
grad_block = grad_info[1]
......
......@@ -14,12 +14,14 @@
"""
This module provides a memory usage calculation function for users.
The purpose of this API is to allow users to estimate memory usage of
a program under a special batch size, then user can set appropriate
batch size to fully utilize a GPU.
a program under a special batch size, then user can set appropriate
batch size to fully utilize a GPU.
This API is still under active development and may change drastically.
"""
import six
from .. import core
from ..framework import Program, Variable
......@@ -45,15 +47,15 @@ def memory_usage(program, batch_size):
Args:
program(Program): The current Program.
batch_size(int): The current input data batch_size.
batch_size(int): The current input data batch_size.
Returns:
min_total_memory(float): the estimate memory usage lower bound.
max_total_memory(float): the estimate memory usage upper bound.
unit_str(string): the unit of estimate usage result.
Examples:
>>> import paddle.fluid as fluid
>>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
fluid.default_main_program(), batch_size=10)
......@@ -72,7 +74,7 @@ def memory_usage(program, batch_size):
# Get the var_name list of first block and calculate
total_memory = 0.0
for var in program.global_block().vars.itervalues():
for var in six.itervalues(program.global_block().vars):
data_count = 1
for x in var.shape:
if x == -1:
......@@ -81,10 +83,10 @@ def memory_usage(program, batch_size):
data_count *= x
var_memory = data_count * dtype_to_size[var.dtype]
if DEBUG:
print "%s memory usage: %d" % (var.name, var_memory)
print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory
if DEBUG:
print "total memory usage: %.2f" % (total_memory)
print("total memory usage: %.2f" % (total_memory))
# Convert appropriate unit
unit_str = "B"
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import sys
import six
import re
from .graphviz import GraphPreviewGenerator
from .proto import framework_pb2
......@@ -225,7 +226,7 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
graph = GraphPreviewGenerator("some graph")
# collect parameters and args
protostr = block.desc.serialize_to_string()
desc = framework_pb2.BlockDesc.FromString(str(protostr))
desc = framework_pb2.BlockDesc.FromString(six.binary_type(protostr))
def need_highlight(name):
if highlights is None: return False
......
......@@ -320,8 +320,9 @@ class Executor(object):
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
assert isinstance(var, Variable) or isinstance(
var, six.string_types), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
......@@ -346,7 +347,7 @@ class Executor(object):
def _fetch_data(self, fetch_list, fetch_var_name, scope):
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in range(len(fetch_list))
for i in six.moves.range(len(fetch_list))
]
return outs
......
......@@ -19,6 +19,7 @@ import six
import numpy as np
from .. import compat as cpt
from .proto import framework_pb2
try:
from . import core
......@@ -27,7 +28,7 @@ except ImportError as e:
"""NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
if you encounter \"libmkldnn.so not found\" errors. If you have python
installed in other directory, replace \"/usr/local/lib\" with your own
directory. The original error is: \n""" + e.message)
directory. The original error is: \n""" + cpt.get_exception_message(e))
except Exception as e:
raise e
from . import unique_name
......@@ -87,7 +88,7 @@ def convert_np_dtype_to_dtype_(np_dtype):
elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8
else:
raise ValueError("Not supported numpy dtype " + six.binary_type(dtype))
raise ValueError("Not supported numpy dtype %s" % dtype)
def dtype_is_floating(dtype):
......@@ -198,11 +199,11 @@ class Variable(object):
if name is None:
name = unique_name.generate('_generated_var')
is_new_var = False
name = name if isinstance(name, six.binary_type) else name.encode()
self.desc = self.block.desc.find_var(name)
name = cpt.to_text(name)
self.desc = self.block.desc.find_var(cpt.to_bytes(name))
if self.desc is None:
self.desc = self.block.desc.var(name)
self.desc = self.block.desc.var(cpt.to_bytes(name))
is_new_var = True
if is_new_var:
......@@ -325,7 +326,7 @@ class Variable(object):
@property
def name(self):
return self.desc.name()
return cpt.to_text(self.desc.name())
@name.setter
def name(self, new_name):
......@@ -531,14 +532,7 @@ class Operator(object):
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
else:
if isinstance(arg.name, six.string_types):
in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
in_arg_names.append(cpt.to_text(arg.name))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......@@ -567,14 +561,7 @@ class Operator(object):
(out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
if isinstance(arg.name, six.string_types):
out_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
out_arg_names.append(cpt.to_text(arg.name))
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
......@@ -970,10 +957,9 @@ class Block(object):
Variable: the Variable with the giving name.
"""
if not isinstance(name, six.string_types):
if not isinstance(name, six.binary_type):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
v = self.vars.get(name, None)
if v is None:
raise ValueError("var %s not in this block" % name)
......@@ -1024,7 +1010,7 @@ class Block(object):
return list(self.iter_parameters())
def iter_parameters(self):
return (item[1] for item in list(self.vars.items())
return (item[1] for item in six.iteritems(self.vars)
if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs):
......@@ -1052,6 +1038,9 @@ class Block(object):
Returns:
Variable: the Variable with the giving name.
"""
name = cpt.to_text(name)
new_name = cpt.to_text(new_name)
if not self.has_var(name):
raise ValueError("var %s is not in current block" % name)
v = self.var(name)
......@@ -1070,9 +1059,9 @@ class Block(object):
else:
raise ValueError("unsupported var type: %s", type(v))
orig_var_type = v.type
self.desc._rename_var(name, new_name)
self.desc._rename_var(cpt.to_bytes(name), cpt.to_bytes(new_name))
# NOTE: v is destroyed by C++ after calling _rename_var.
d = self.desc.find_var(new_name)
d = self.desc.find_var(cpt.to_bytes(new_name))
if var_type == "Parameter":
var = Parameter(
self,
......@@ -1103,7 +1092,7 @@ class Block(object):
def _remove_var(self, name):
self._sync_with_cpp()
self.desc._remove_var(name)
self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name]
def create_parameter(self, *args, **kwargs):
......@@ -1205,7 +1194,7 @@ class Block(object):
# sync variables removed from c++ end
for var in list(self.vars.keys()):
if not self.desc.find_var(var):
if not self.desc.find_var(cpt.to_bytes(var)):
self.vars.pop(var)
# sync operators from cpp
......@@ -1576,7 +1565,9 @@ class Program(object):
p.current_block_idx = self.current_block_idx
p._seed = self._seed
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.blocks = [
Block(p, i) for i in six.moves.range(self.desc.num_blocks())
]
p._current_role = self._current_role
p._op_role_var = self._op_role_var
......@@ -1632,7 +1623,9 @@ class Program(object):
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())]
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
return res
......@@ -1675,16 +1668,18 @@ class Program(object):
root_block._remove_op(0, read_op_idx + 1)
for var in root_block.all_vars():
if var.type() == core.VarDesc.VarType.READER:
root_block._remove_var(var.name())
root_block._remove_var(cpt.to_bytes(var.name()))
# change all `is_test` attributes to True
for i in range(res.desc.num_blocks()):
for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i)
for j in range(block.op_size()):
for j in six.moves.range(block.op_size()):
op = block.op(j)
if op.has_attr('is_test'):
op.set_attr('is_test', True)
res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())]
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
return res
......@@ -1704,7 +1699,7 @@ class Program(object):
"""
p = Program()
p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in range(p.desc.num_blocks())]
p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
p._sync_with_cpp()
return p
......
......@@ -15,6 +15,7 @@
import os
import random
import six
import functools
import subprocess
import logging
......@@ -105,8 +106,9 @@ class Graph(object):
def _rank_repr(self):
ranks = sorted(
list(self.rank_groups.items()),
cmp=lambda a, b: a[1].priority > b[1].priority)
six.iteritems(self.rank_groups),
key=functools.cmp_to_key(
lambda a, b: a[1].priority > b[1].priority))
repr = []
for x in ranks:
repr.append(str(x[1]))
......@@ -149,7 +151,7 @@ class Node(object):
name=self.name,
label=self.label,
extra=',' + ','.join("%s=%s" % (key, crepr(value))
for key, value in list(self.attrs.items()))
for key, value in six.iteritems(self.attrs))
if self.attrs else "")
return reprs
......@@ -173,7 +175,7 @@ class Edge(object):
target=self.target.name,
extra="" if not self.attrs else
"[" + ','.join("{}={}".format(attr[0], crepr(attr[1]))
for attr in list(self.attrs.items())) + "]")
for attr in six.iteritems(self.attrs)) + "]")
return repr
......
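Python 3 removed the cmp= argument of sort/sorted; functools.cmp_to_key adapts an old-style comparator into a key function. A standalone sketch (note that a well-formed comparator should return a negative, zero, or positive number, not a bool):

import functools

items = [('a', 3), ('b', 1), ('c', 2)]
by_priority = sorted(items, key=functools.cmp_to_key(lambda a, b: a[1] - b[1]))
assert by_priority == [('b', 1), ('c', 2), ('a', 3)]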
......@@ -603,25 +603,15 @@ def save_inference_model(dirname,
# "./infer_model".
"""
if isinstance(feeded_var_names, six.binary_type):
if isinstance(feeded_var_names, six.string_types):
feeded_var_names = [feeded_var_names]
elif isinstance(feeded_var_names, six.text_type):
feeded_var_names = [feeded_var_names.encode()]
else:
if len(feeded_var_names) > 0:
# TODO(paddle-dev): polish these code blocks
if not (bool(feeded_var_names) and all(
isinstance(name, six.binary_type)
isinstance(name, six.string_types)
for name in feeded_var_names)):
if not (all(
isinstance(name, six.text_type)
for name in feeded_var_names)):
raise ValueError(
"'feed_var_names' should be a list of str.")
else:
feeded_var_names = [
name.encode() for name in feeded_var_names
]
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
target_vars = [target_vars]
......
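With six.string_types the check accepts both str and unicode on Python 2 and plain str on Python 3, so one isinstance test replaces the separate text/bytes branches deleted above. A hedged sketch of the normalization pattern (the helper name is made up):

import six

def _normalize_names(names):
    # Accept a single name or a list of names; always return a list of str.
    if isinstance(names, six.string_types):
        return [names]
    if not (bool(names) and
            all(isinstance(n, six.string_types) for n in names)):
        raise ValueError("'names' should be a (list of) str.")
    return list(names)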
......@@ -85,7 +85,7 @@ class LayerHelper(object):
raise ValueError("parameter number mismatch")
elif len(param_attr) == 1 and length != 1:
tmp = [None] * length
for i in range(length):
for i in six.moves.range(length):
tmp[i] = copy.deepcopy(param_attr[0])
param_attr = tmp
return param_attr
......
......@@ -22,6 +22,7 @@ from ..initializer import force_init_on_cpu
from .ops import logical_and, logical_not, logical_or
import numpy
import warnings
import six
from functools import reduce
__all__ = [
......@@ -602,7 +603,7 @@ class StaticRNN(object):
boot_memories = []
pre_memories = []
memories = []
for _, mem in list(self.memories.items()):
for _, mem in six.iteritems(self.memories):
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name)
......
......@@ -21,7 +21,9 @@ from ..layer_helper import LayerHelper
from . import tensor
from . import nn
from . import ops
from ... import compat as cpt
import math
import six
import numpy
from functools import reduce
......@@ -104,7 +106,7 @@ def rpn_target_assign(loc,
examples.
Returns:
tuple:
tuple:
A tuple(predicted_scores, predicted_location, target_label,
target_bbox) is returned. The predicted_scores and
predicted_location is the predicted result of the RPN.
......@@ -115,7 +117,7 @@ def rpn_target_assign(loc,
anchors. The predicted_scores is a 2D Tensor with shape
[F + B, 1], and the shape of target_label is same as the shape
of the predicted_scores, B is the number of the background
anchors, the F and B is depends on the input of this operator.
anchors, the F and B is depends on the input of this operator.
Examples:
.. code-block:: python
......@@ -232,8 +234,8 @@ def detection_output(loc,
nms_eta(float): The parameter for adaptive NMS.
Returns:
Variable:
Variable:
The detection outputs is a LoDTensor with shape [No, 6].
Each row has six values: [label, confidence, xmin, ymin, xmax, ymax].
`No` is the total number of detections in this mini-batch. For each
......@@ -504,7 +506,7 @@ def target_assign(input,
Assumed that the row offset for each instance in `neg_indices` is called neg_lod,
for i-th instance and each `id` of neg_indices in this instance:
.. code-block:: text
out[i][id][0 : K] = {mismatch_value, mismatch_value, ...}
......@@ -522,11 +524,11 @@ def target_assign(input,
mismatch_value (float32): Fill this value to the mismatched location.
Returns:
tuple:
A tuple(out, out_weight) is returned. out is a 3D Tensor with
shape [N, P, K], N and P is the same as they are in
`neg_indices`, K is the same as it in input of X. If
`match_indices[i][j]`. out_weight is the weight for output with
tuple:
A tuple(out, out_weight) is returned. out is a 3D Tensor with
shape [N, P, K], N and P is the same as they are in
`neg_indices`, K is the same as it in input of X. If
`match_indices[i][j]`. out_weight is the weight for output with
the shape of [N, P, 1].
Examples:
......@@ -834,7 +836,7 @@ def prior_box(input,
offset(float): Prior boxes center offset. Default: 0.5
name(str): Name of the prior box op. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the final
detection results. Default: False.
......@@ -977,7 +979,7 @@ def multi_box_head(inputs,
stride(int|list|tuple): The stride of conv2d. Default:1,
name(str): Name of the prior box layer. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the final
detection results. Default: False.
......@@ -1039,7 +1041,7 @@ def multi_box_head(inputs,
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in range(min_ratio, max_ratio + 1, step):
for ratio in six.moves.range(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
max_sizes.append(base_size * (ratio + step) / 100.)
min_sizes = [base_size * .10] + min_sizes
......@@ -1108,8 +1110,8 @@ def multi_box_head(inputs,
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
compile_shape = [
mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
mbox_loc.shape[0], cpt.floor_division(
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4
]
run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32"))
mbox_loc_flatten = nn.reshape(
......@@ -1127,8 +1129,9 @@ def multi_box_head(inputs,
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
new_shape = [0, -1, num_classes]
compile_shape = [
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3] / num_classes, num_classes
conf_loc.shape[0],
cpt.floor_division(conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3], num_classes), num_classes
]
run_shape = tensor.assign(
numpy.array([0, -1, num_classes]).astype("int32"))
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import multiprocessing
import six
import threading
from ..data_feeder import DataFeeder
......@@ -69,7 +70,7 @@ def data(name,
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in range(len(shape)):
for i in six.moves.range(len(shape)):
if shape[i] is None:
shape[i] = -1
append_batch_size = False
......@@ -674,7 +675,7 @@ def py_reader(capacity,
def __tensor_provider__():
for slots in paddle_reader():
yield [slots[str(idx)] for idx in xrange(counter)]
yield [slots[str(idx)] for idx in six.moves.xrange(counter)]
__set_tensor_provider__(__tensor_provider__)
......@@ -750,7 +751,7 @@ def open_files(filenames,
else:
buffer_size = int(buffer_size)
if isinstance(filenames, basestring):
if isinstance(filenames, six.string_types):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
......@@ -1005,7 +1006,7 @@ class Preprocessor(object):
source_lod_levels = self.underlying_reader.desc.lod_levels()
self.source_var_names = [
unique_name("preprocessor_source")
for _ in range(len(source_shapes))
for _ in six.moves.range(len(source_shapes))
]
source_vars = []
for var_name, shape, dtype, lod_level in zip(
......
......@@ -362,7 +362,7 @@ def dynamic_lstm(input,
"""
helper = LayerHelper('lstm', **locals())
size = size / 4
size = size // 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype)
bias_size = [1, 7 * size]
......@@ -552,7 +552,7 @@ def dynamic_lstmp(input,
"""
helper = LayerHelper('lstmp', **locals())
size = size / 4
size = size // 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[proj_size, 4 * size], dtype=dtype)
proj_weight = helper.create_parameter(
......@@ -780,7 +780,7 @@ def gru_unit(input,
helper = LayerHelper('gru_unit', **locals())
dtype = helper.input_dtype()
size = size / 3
size = size // 3
# create weight
weight = helper.create_parameter(
......@@ -1264,7 +1264,7 @@ def sequence_conv(input,
outputs={"Out": pre_bias},
attrs={
'contextStride': filter_stride,
'contextStart': -int(filter_size / 2),
'contextStart': -int(filter_size // 2),
'contextLength': filter_size
})
pre_act = helper.append_bias_op(pre_bias)
......@@ -1496,7 +1496,7 @@ def conv2d(input,
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
......@@ -1507,7 +1507,7 @@ def conv2d(input,
raise ValueError("use_cudnn should be True or False")
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
......@@ -1658,7 +1658,7 @@ def conv3d(input,
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 3, 'filter_size')
stride = utils.convert_to_list(stride, 3, 'stride')
......@@ -2393,16 +2393,16 @@ def conv2d_transpose(input,
w_in = input.shape[3]
filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 *
padding[0] - 1) / dilation[0] + 1
padding[0] - 1) // dilation[0] + 1
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1
padding[1] - 1) // dilation[1] + 1
filter_size = [filter_size_h, filter_size_w]
else:
filter_size = utils.convert_to_list(filter_size, 2,
'conv2d_transpose.filter_size')
groups = 1 if groups is None else groups
filter_shape = [input_channel, num_filters / groups] + filter_size
filter_shape = [input_channel, num_filters // groups] + filter_size
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
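The switch from / to // keeps the inferred filter size an int under Python 3, where / would produce a float and corrupt the shape list. A worked example with made-up values:

# h_in=8, stride=2, padding=1, dilation=1, desired output height 16
h_in, stride, padding, dilation, out = 8, 2, 1, 1, 16
filter_size_h = (out - (h_in - 1) * stride + 2 * padding - 1) // dilation + 1
assert filter_size_h == 4
# sanity check against the transposed-conv output formula
assert (h_in - 1) * stride - 2 * padding + filter_size_h == out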
......@@ -2560,18 +2560,18 @@ def conv3d_transpose(input,
w_in = input.shape[4]
filter_size_d = (output_size[0] - (d_in - 1) * stride[0] + 2 *
padding[0] - 1) / dilation[0] + 1
padding[0] - 1) // dilation[0] + 1
filter_size_h = (output_size[1] - (h_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1
padding[1] - 1) // dilation[1] + 1
filter_size_w = (output_size[2] - (w_in - 1) * stride[2] + 2 *
padding[2] - 1) / dilation[2] + 1
padding[2] - 1) // dilation[2] + 1
filter_size = [filter_size_d, filter_size_h, filter_size_w]
else:
filter_size = utils.convert_to_list(filter_size, 3,
'conv3d_transpose.filter_size')
groups = 1 if groups is None else groups
filter_shape = [input_channel, num_filters / groups] + filter_size
filter_shape = [input_channel, num_filters // groups] + filter_size
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
......@@ -2678,15 +2678,15 @@ def beam_search(pre_ids,
Refer to `Beam search <https://en.wikipedia.org/wiki/Beam_search>`_
for more details.
This layer does the search in beams for one time step. Specifically, it
This layer does the search in beams for one time step. Specifically, it
selects the top-K candidate word ids of current step from :attr:`ids`
according to their :attr:`scores` for all source sentences, where K is
:attr:`beam_size` and :attr:`ids, scores` are predicted results from the
computation cell. Additionally, :attr:`pre_ids` and :attr:`pre_scores` are
the output of beam_search at previous step, they are needed for special use
to handle ended candidate translations.
Note that the :attr:`scores` passed in should be accumulated scores, and
length penalty should be done with extra operators before calculating the
accumulated scores if needed, also suggest finding top-K before it and
......@@ -3887,7 +3887,7 @@ def nce(input,
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each
......@@ -3897,9 +3897,9 @@ def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
......@@ -3907,7 +3907,7 @@ def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
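To make the tree structure concrete, here is a hedged sketch of a root-to-leaf path in a complete binary tree with `num_classes` leaves (an illustration only; the operator's actual node numbering is an assumption here):

def leaf_path(c, num_classes):
    # Heap layout assumed: internal nodes 0..num_classes-2, leaves after them.
    node = c + num_classes - 1
    path = []
    while node > 0:
        parent = (node - 1) // 2
        path.append((parent, node % 2))   # (classifier node, left/right decision)
        node = parent
    return list(reversed(path))

print(leaf_path(3, num_classes=8))        # three binary decisions for 8 classes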
......@@ -5306,23 +5306,23 @@ def rank_loss(label, left, right, name=None):
is a pairwise ranking model with a training sample consisting of a pair
of documents, A and B. Label P indicates whether A is ranked higher than B
or not:
P = {0, 1} or {0, 0.5, 1}, where 0.5 means that there is no information
about the rank of the input pair.
Rank loss layer takes three inputs: left (o_i), right (o_j) and
label (P_{i,j}). The inputs respectively represent RankNet's output scores
for documents A and B and the value of label P. The following equation
computes rank loss C_{i,j} from the inputs:
$$
C_{i,j} = -\tilde{P}_{i,j} \cdot o_{i,j} + \log(1 + e^{o_{i,j}}) \\
o_{i,j} = o_i - o_j \\
\tilde{P}_{i,j} \in \left\{0, 0.5, 1\right\} \ \text{or} \ \left\{0, 1\right\}
$$
Rank loss layer takes batch inputs with size batch_size (batch_size >= 1).
Args:
label (Variable): Indicates whether A is ranked higher than B or not.
left (Variable): RankNet's output score for doc A.
......
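The formula is easy to check numerically (a small sketch with made-up scores; `log1p` is used for numerical stability):

import numpy as np

o_i, o_j, label = 2.0, 1.0, 1.0             # RankNet scores for A and B, and P
o_ij = o_i - o_j
loss = -label * o_ij + np.log1p(np.exp(o_ij))
print(loss)                                 # ~0.3133, the C_{i,j} defined above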
......@@ -14,11 +14,12 @@
"""
Fluid Metrics
The metrics are accomplished via Python natively.
"""
import numpy as np
import copy
import warnings
import six
__all__ = [
'MetricBase',
......@@ -79,10 +80,10 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in list(self.__dict__.items())
for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_")
}
for attr, value in list(states.items()):
for attr, value in six.iteritems(states):
if isinstance(value, int):
setattr(self, attr, 0)
elif isinstance(value, float):
......@@ -105,7 +106,7 @@ class MetricBase(object):
"""
states = {
attr: value
for attr, value in list(self.__dict__.items())
for attr, value in six.iteritems(self.__dict__)
if not attr.startswith("_")
}
config = {}
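The `list(d.items())` to `six.iteritems(d)` substitution recurs throughout this commit: it iterates lazily on both interpreters instead of materializing an extra list under Python 3. A one-line check of the idiom:

import six

d = {'value': 1, '_private': 2}
states = {k: v for k, v in six.iteritems(d) if not k.startswith('_')}
print(states)   # {'value': 1}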
......@@ -141,10 +142,10 @@ class CompositeMetric(MetricBase):
"""
Composite multiple metrics in one instance.
For example, merge F1, accuracy, and recall into one metric.
Examples:
.. code-block:: python
labels = fluid.layers.data(name="label", shape=[1], dtype="int32")
data = fluid.layers.data(name="data", shape=[32, 32], dtype="int32")
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from . import layers
__all__ = [
......@@ -210,7 +211,7 @@ def img_conv_group(input,
conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
for i in range(len(conv_num_filter)):
for i in six.moves.range(len(conv_num_filter)):
local_conv_act = conv_act
if conv_with_batchnorm[i]:
local_conv_act = None
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import six
import paddle.fluid.core as core
......@@ -99,6 +100,8 @@ class OpDescCreationMethod(object):
new_attr = op_desc.attrs.add()
new_attr.name = attr.name
new_attr.type = attr.type
if isinstance(user_defined_attr, np.ndarray):
user_defined_attr = user_defined_attr.tolist()
if attr.type == framework_pb2.INT:
new_attr.i = user_defined_attr
elif attr.type == framework_pb2.FLOAT:
......
......@@ -17,8 +17,10 @@ import multiprocessing
from . import core
from . import framework
from . import executor
from .. import compat as cpt
import warnings
import sys
import six
import os
__all__ = ['ParallelExecutor', 'ExecutionStrategy', 'BuildStrategy']
......@@ -95,7 +97,7 @@ class ParallelExecutor(object):
self._places = []
self._act_places = []
if use_cuda:
for i in range(core.get_cuda_device_count()):
for i in six.moves.range(core.get_cuda_device_count()):
p = core.Place()
self._act_places.append(core.CUDAPlace(i))
p.set_place(self._act_places[-1])
......@@ -103,7 +105,7 @@ class ParallelExecutor(object):
else:
cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
for i in range(cpu_num):
for i in six.moves.range(cpu_num):
p = core.Place()
self._act_places.append(core.CPUPlace())
p.set_place(self._act_places[-1])
......@@ -153,11 +155,13 @@ class ParallelExecutor(object):
self.executor = core.ParallelExecutor(
self._places,
set([
p.name for p in main.global_block().iter_parameters()
cpt.to_text(p.name)
for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
set(self.persistable_vars), main.desc, loss_name
if loss_name else '', scope, local_scopes, exec_strategy,
set(cpt.to_text(var) for var in self.persistable_vars), main.desc,
cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id)
self.scope = scope
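The `cpt.to_text` / `cpt.to_bytes` helpers from `paddle.compat` normalize names before they cross the pybind boundary. A hedged sketch of what they do in the common case (assumption: the real helpers also handle containers and other corner cases):

def to_text(obj, encoding='utf-8'):
    return obj.decode(encoding) if isinstance(obj, bytes) else obj

def to_bytes(obj, encoding='utf-8'):
    return obj.encode(encoding) if isinstance(obj, str) else obj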
......@@ -269,7 +273,7 @@ class ParallelExecutor(object):
self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@'
self.executor.run(fetch_list, fetch_var_name)
self.executor.run(cpt.to_text(fetch_list), cpt.to_text(fetch_var_name))
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
if self.is_dist:
......
......@@ -15,6 +15,7 @@
from . import core
from contextlib import contextmanager
import os
import six
__all__ = [
'cuda_profiler', 'reset_profiler', 'profiler', 'start_profiler',
......@@ -88,7 +89,7 @@ def cuda_profiler(output_file, output_mode=None, config=None):
config = NVPROF_CONFIG if config is None else config
config_file = 'nvprof_config_file'
with open(config_file, 'wb') as fp:
fp.writelines(["%s\n" % item for item in config])
fp.writelines([six.b("%s\n" % item) for item in config])
core.nvprof_init(output_file, output_mode, config_file)
# Enables profiler collection by the active CUDA profiling tool.
core.nvprof_start()
......
......@@ -30,10 +30,10 @@ images per class.
import itertools
import numpy
import paddle.v2.dataset.common
import paddle.dataset.common
import tarfile
import six
from six.moves import cPickle as pickle
from six.moves import zip
__all__ = ['train10']
......@@ -44,20 +44,25 @@ CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
def reader_creator(filename, sub_name, batch_size=None):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
data = batch[six.b('data')]
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in zip(data, labels):
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f
if sub_name in each_item.name)
names = [
each_item.name for each_item in f if sub_name in each_item.name
]
batch_count = 0
for name in names:
batch = pickle.load(f.extractfile(name))
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
if isinstance(batch_size, int) and batch_count > batch_size:
break
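The CIFAR batches were pickled by Python 2, so their dict keys are byte strings; `encoding='bytes'` keeps them that way under Python 3, and `six.b` builds matching keys on both interpreters:

import six

key = six.b('labels')           # b'labels' on Python 3, plain 'labels' on Python 2
assert isinstance(key, bytes)   # holds on both, since Python 2's str is bytes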
......@@ -78,6 +83,6 @@ def train10(batch_size=None):
:rtype: callable
"""
return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch',
batch_size=batch_size)
......@@ -55,7 +55,7 @@ def resnet_cifar10(input, depth=32):
return tmp
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
......
......@@ -60,7 +60,7 @@ def resnet_cifar10(input, depth=32):
return tmp
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
......
......@@ -56,7 +56,7 @@ def resnet_cifar10(input, depth=32):
return tmp
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy
import six
import paddle
import paddle.dataset.mnist as mnist
......@@ -31,7 +32,7 @@ def network(is_train):
hidden = img
for i in xrange(2):
for i in six.moves.xrange(2):
hidden = fluid.layers.fc(input=hidden, size=100, act='tanh')
hidden = fluid.layers.dropout(
hidden, dropout_prob=0.5, is_test=not is_train)
......@@ -74,7 +75,7 @@ def main():
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
for epoch_id in xrange(10):
for epoch_id in six.moves.xrange(10):
train_reader.start()
try:
while True:
......
......@@ -54,7 +54,7 @@ class BenchmarkSuite(OpTest):
def _get_input_names(self):
inputs = []
for name, value in list(self.inputs.items()):
for name, value in six.iteritems(self.inputs):
if isinstance(value, list):
inputs.extend([sub_name for sub_name, _ in value])
inputs.append(name)
......@@ -62,7 +62,7 @@ class BenchmarkSuite(OpTest):
def _get_output_names(self):
outputs = []
for var_name, var in list(self.outputs.items()):
for var_name, var in six.iteritems(self.outputs):
if isinstance(var, list):
for sub_var_name, sub_var in var:
outputs.append(sub_var_name)
......
......@@ -173,7 +173,7 @@ class SE_ResNeXt():
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) / 2,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
# avoid pserver CPU init differs from GPU
......@@ -187,7 +187,7 @@ class SE_ResNeXt():
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool,
size=num_channels / reduction_ratio,
size=num_channels // reduction_ratio,
act='relu')
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze,
......
......@@ -22,6 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid import core
import os
import sys
import six
import transformer_model
import paddle.dataset.wmt16 as wmt16
......@@ -159,6 +160,7 @@ def get_model():
avg_cost = transformer(use_feed=False)
optimizer = fluid.optimizer.Adam()
optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
return avg_cost
......@@ -222,7 +224,7 @@ class DistTransformer2x2(object):
first_loss, = exe.run(fetch_list=[avg_cost.name])
print(first_loss)
for i in xrange(5):
for i in six.moves.xrange(5):
_ = exe.run(fetch_list=[avg_cost.name])
last_loss, = exe.run(fetch_list=[avg_cost.name])
print(last_loss)
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
import random
import six
import time
import itertools
import collections
......@@ -26,15 +27,13 @@ from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, OpProtoHolder, Variable
from testsuite import create_op, set_input, append_input_output, append_loss_ops
from functools import reduce
from six.moves import zip
def randomize_probability(batch_size, class_num, dtype='float32'):
prob = np.random.uniform(
0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
prob_sum = prob.sum(axis=1)
for i in range(len(prob)):
for i in six.moves.xrange(len(prob)):
prob[i] /= prob_sum[i]
return prob
......@@ -51,7 +50,7 @@ def get_numeric_gradient(place,
set_input(scope, op, inputs, place)
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
return six.moves.reduce(lambda a, b: a * b, dim, 1)
def get_output():
sum = []
......@@ -103,7 +102,7 @@ def get_numeric_gradient(place,
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in range(tensor_size):
for i in six.moves.xrange(tensor_size):
if in_place:
set_input(scope, op, inputs, place)
......@@ -161,7 +160,7 @@ class OpTest(unittest.TestCase):
assert isinstance(
numpy_dict,
dict), "self.inputs, self.outputs must be numpy_dict"
for var_name, var_value in numpy_dict.items():
for var_name, var_value in six.iteritems(numpy_dict):
if isinstance(var_value, (np.ndarray, np.generic)):
self.try_call_once(var_value.dtype)
elif isinstance(var_value, (list, tuple)):
......@@ -225,7 +224,7 @@ class OpTest(unittest.TestCase):
def _get_io_vars(self, block, numpy_inputs):
inputs = {}
for name, value in numpy_inputs.items():
for name, value in six.iteritems(numpy_inputs):
if isinstance(value, list):
var_list = [
block.var(sub_name) for sub_name, sub_value in value
......@@ -268,7 +267,7 @@ class OpTest(unittest.TestCase):
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.
if len(fetch_list) == 0:
for var_name, var in outputs.items():
for var_name, var in six.iteritems(outputs):
if isinstance(var, list):
for v in var:
fetch_list.append(v)
......@@ -366,12 +365,13 @@ class OpTest(unittest.TestCase):
for place in places:
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
outs.sort(key=len)
checker(outs)
def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
for a, b, name in zip(numeric_grads, analytic_grads, names):
for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names):
abs_a = np.abs(a)
abs_a[abs_a < 1e-3] = 1
......
This diff has been collapsed.
......@@ -24,12 +24,12 @@ def conv2d_forward_naive(input, filter, group, conv_param):
out_c, f_c, f_h, f_w = filter.shape
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
sub_out_c = out_c / group
sub_out_c = out_c // group
stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
'dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) / stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1]
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
out = np.zeros((in_n, out_c, out_h, out_w))
d_bolck_h = (dilation[0] * (f_h - 1) + 1)
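The output-size formula in this reference implementation is easy to sanity-check with hypothetical sizes:

in_h, pad, dilation, f_h, stride = 5, 1, 1, 3, 1
out_h = 1 + (in_h + 2 * pad - (dilation * (f_h - 1) + 1)) // stride
print(out_h)   # 5: a 3x3 kernel with padding 1 and stride 1 preserves the size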
......@@ -138,7 +138,7 @@ class TestConv2dOp(OpTest):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
......@@ -157,7 +157,7 @@ class TestWithPad(TestConv2dOp):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
......@@ -167,7 +167,7 @@ class TestWithStride(TestConv2dOp):
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
......@@ -182,7 +182,7 @@ class TestWith1x1(TestConv2dOp):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
......@@ -195,7 +195,7 @@ class TestWithDilation(TestConv2dOp):
self.stride = [1, 1]
self.input_size = [2, 3, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
......@@ -211,7 +211,7 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
......@@ -328,7 +328,7 @@ class TestDepthwiseConv(TestConv2dOp):
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
......@@ -340,7 +340,7 @@ class TestDepthwiseConv2(TestConv2dOp):
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
......
......@@ -25,7 +25,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
groups = attrs['groups']
assert in_c == f_c
out_c = f_out_c * groups
sub_in_c = in_c / groups
sub_in_c = in_c // groups
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
......@@ -258,7 +258,7 @@ class TestDepthwiseConvTranspose(TestConv2dTransposeOp):
self.input_size = [2, 8, 16, 16] # NCHW
self.groups = 8
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [self.input_size[1], f_c, 4, 4]
self.op_type = "depthwise_conv2d_transpose"
......
......@@ -24,14 +24,14 @@ def conv3d_forward_naive(input, filter, group, conv_param):
out_c, f_c, f_d, f_h, f_w = filter.shape
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
sub_out_c = out_c / group
sub_out_c = out_c // group
stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
'dilations']
out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) / stride[0]
out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) / stride[1]
out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) / stride[2]
out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) // stride[0]
out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) // stride[1]
out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) // stride[2]
out = np.zeros((in_n, out_c, out_d, out_h, out_w))
......@@ -166,7 +166,7 @@ class TestConv3dOp(OpTest):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
def init_dilation(self):
......@@ -185,7 +185,7 @@ class TestCase1(TestConv3dOp):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3, 3]
......@@ -205,7 +205,7 @@ class TestWith1x1(TestConv3dOp):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
......@@ -221,7 +221,7 @@ class TestWithInput1x1Filter1x1(TestConv3dOp):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 1, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
......@@ -237,7 +237,7 @@ class TestWithDilation(TestConv3dOp):
self.stride = [1, 1, 1]
self.input_size = [2, 3, 6, 6, 6] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 2, 2, 2]
def init_dilation(self):
......
......@@ -25,7 +25,7 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
groups = attrs['groups']
assert in_c == f_c
out_c = f_out_c * groups
sub_in_c = in_c / groups
sub_in_c = in_c // groups
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
......
......@@ -21,7 +21,7 @@ def conv_shift_forward(x, y):
out = np.zeros_like(x)
M = x.shape[1]
N = y.shape[1]
y_half_width = (N - 1) / 2
y_half_width = (N - 1) // 2
for i in range(M):
for j in range(N):
out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j]
......
......@@ -14,7 +14,7 @@
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle
import numpy as np
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import six
import sys
import collections
import math
......@@ -176,7 +177,7 @@ class TestDetectionMAPOp(OpTest):
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
for (label, label_pos_num) in list(label_count.items()):
for (label, label_pos_num) in six.iteritems(label_count):
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
......
......@@ -16,6 +16,7 @@ import time
import unittest
import os
import sys
import six
import signal
import subprocess
......@@ -123,6 +124,9 @@ def runtime_main(test_class):
model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
import paddle.compat as cpt
class TestDistBase(unittest.TestCase):
def setUp(self):
self._trainers = 2
......@@ -209,7 +213,7 @@ class TestDistBase(unittest.TestCase):
local_proc.wait()
out, err = local_proc.communicate()
local_ret = out
local_ret = cpt.to_text(out)
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
......@@ -256,7 +260,7 @@ class TestDistBase(unittest.TestCase):
tr1_proc.wait()
out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = out
loss_data0 = cpt.to_text(out)
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
......
......@@ -49,6 +49,7 @@ class TranspilerTest(unittest.TestCase):
def get_main_program(self):
main = fluid.Program()
main.random_seed = 1
with fluid.program_guard(main):
self.net_conf()
self.origin_prog = main.clone()
......
......@@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.compat as cpt
import paddle.fluid.core as core
import unittest
class TestException(unittest.TestCase):
def test_exception(self):
ex = None
exception = None
try:
core.__unittest_throw_exception__()
except core.EnforceNotMet as ex:
self.assertIn("test exception", ex.message)
self.assertIn("test exception", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(ex)
self.assertIsNotNone(exception)
if __name__ == "__main__":
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
import math
import functools
from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu
......@@ -38,7 +39,8 @@ class TestGRUOp(OpTest):
for i in range(len(seq_lens)):
seq_starts.append(seq_starts[-1] + seq_lens[i])
sorted_seqs = sorted(
list(range(len(seq_lens))), lambda x, y: seq_lens[y] - seq_lens[x])
list(range(len(seq_lens))),
key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
num_batch = seq_lens[sorted_seqs[0]]
for batch_idx in range(num_batch):
idx_in_seq = []
......
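Python 3 removes the `cmp` argument of `sorted`, so the old comparator is wrapped with `functools.cmp_to_key`. A quick check of the idiom:

import functools

seq_lens = [3, 1, 2]
order = sorted(range(len(seq_lens)),
               key=functools.cmp_to_key(lambda x, y: seq_lens[y] - seq_lens[x]))
print(order)   # [0, 2, 1]: indices ordered by decreasing sequence length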
......@@ -14,6 +14,7 @@
import unittest
import six
import paddle.fluid.core as core
......@@ -27,14 +28,14 @@ class TestInferShape(unittest.TestCase):
shape = [10, 20]
# prepare input/output
x1 = block.var("x1")
x1 = block.var(six.b("x1"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
x2 = block.var("x2")
x2 = block.var(six.b("x2"))
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(shape)
out = block.var("out")
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
......@@ -57,14 +58,14 @@ class TestInferShape(unittest.TestCase):
y_shape = [20, 30]
# prepare input/output
x1 = block.var("x")
x1 = block.var(six.b("x"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(x_shape)
x2 = block.var("y")
x2 = block.var(six.b("y"))
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(y_shape)
out = block.var("out")
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
......
......@@ -14,6 +14,7 @@
import unittest
import six
import numpy as np
import paddle.fluid.core as core
......@@ -48,7 +49,7 @@ class TestBook(unittest.TestCase):
exe.run(init_program, feed={}, fetch_list=[])
for i in range(100):
for i in six.moves.xrange(100):
tensor_x = np.array(
[[1, 1], [1, 2], [3, 4], [5, 2]]).astype("float32")
tensor_y = np.array([[-2], [-3], [-7], [-7]]).astype("float32")
......@@ -64,7 +65,7 @@ class TestBook(unittest.TestCase):
'y': tensor_y},
fetch_list=[avg_cost])[0]
reload(executor) # reload to build a new scope
six.moves.reload_module(executor) # reload to build a new scope
exe = executor.Executor(place)
[infer_prog, feed_var_names, fetch_vars] = load_inference_model(
......
......@@ -159,7 +159,7 @@ class TestBook(unittest.TestCase):
input=crf_decode,
label=label,
chunk_scheme="IOB",
num_chunk_types=(label_dict_len - 1) / 2)
num_chunk_types=(label_dict_len - 1) // 2)
self.assertFalse(crf is None)
self.assertFalse(crf_decode is None)
......@@ -286,7 +286,7 @@ class TestBook(unittest.TestCase):
name='word_{0}'.format(i), shape=[1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
label_word = int(window_size // 2) + 1
embs = []
for i in range(window_size):
......
......@@ -17,6 +17,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.compat as cpt
class TestLookupTableOp(OpTest):
......@@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': long(padding_idx)}
self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
def test_check_grad(self):
......
......@@ -34,7 +34,7 @@ class TestLRNOp(OpTest):
return x + 1
def get_out(self):
start = -(self.n - 1) / 2
start = -(self.n - 1) // 2
end = start + self.n
mid = np.empty((self.N, self.C, self.H, self.W)).astype("float32")
......
......@@ -19,7 +19,7 @@ from op_test import OpTest
def maxout_forward_naive(input, groups):
s0, s1, s2, s3 = input.shape
return np.ndarray([s0, s1 / groups, groups, s2, s3], \
return np.ndarray([s0, s1 // groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
......
......@@ -15,6 +15,7 @@
import unittest
import paddle.fluid.core as core
import paddle.compat as cpt
from paddle.fluid.framework import Program, default_startup_program
......@@ -29,14 +30,15 @@ class TestOperator(unittest.TestCase):
self.assertFail()
except ValueError as v_err:
self.assertEqual(
v_err.message,
cpt.get_exception_message(v_err),
"`type` to initilized an Operator can not be None.")
try:
block.append_op(type="no_such_op")
self.assertFail()
except ValueError as a_err:
self.assertEqual(a_err.message,
"Operator \"no_such_op\" has not been registered.")
self.assertEqual(
cpt.get_exception_message(a_err),
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self):
program = Program()
......
......@@ -46,7 +46,7 @@ def squeeze_excitation(input, num_channels, reduction_ratio):
pool = fluid.layers.reduce_mean(input=reshape, dim=2)
squeeze = fluid.layers.fc(input=pool,
size=num_channels / reduction_ratio,
size=num_channels // reduction_ratio,
act='relu')
excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
......@@ -62,7 +62,7 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) / 2,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
......
......@@ -29,11 +29,11 @@ def max_pool2D_forward_naive(x,
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
......@@ -57,11 +57,11 @@ def avg_pool2D_forward_naive(x,
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
......
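Under `//` the ceil_mode branch still rounds up, because adding `stride - 1` before a floor division is an integer ceiling. A numeric check with made-up sizes:

H, k, pad, stride = 5, 2, 0, 2
floor_out = (H - k + 2 * pad) // stride + 1                # 2
ceil_out = (H - k + 2 * pad + stride - 1) // stride + 1    # 3
print(floor_out, ceil_out)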
......@@ -29,14 +29,14 @@ def max_pool3D_forward_naive(x,
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) / strides[2] + 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
......@@ -63,14 +63,14 @@ def avg_pool3D_forward_naive(x,
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) / strides[2] + 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
......
......@@ -24,9 +24,9 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False):
ksize = [D, H, W]
paddings = [0, 0, 0]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
mask = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
......@@ -63,8 +63,8 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False):
ksize = [H, W]
paddings = [0, 0]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
......
......@@ -15,6 +15,7 @@
import unittest
import itertools
import numpy as np
import six
from op_test import OpTest
......@@ -32,7 +33,7 @@ def py_pnpair_op(score, label, query, column=-1, weight=None):
# accumulate statistics
pos, neg, neu = 0, 0, 0
for _, ranks in list(predictions.items()):
for _, ranks in six.iteritems(predictions):
for e1, e2 in itertools.combinations(ranks, 2):
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
w = (w1 + w2) * 0.5
......
......@@ -15,9 +15,9 @@
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
import paddle.dataset.mnist as mnist
class TestPreprocessor(unittest.TestCase):
......
......@@ -93,7 +93,7 @@ class TestProfiler(unittest.TestCase):
"profiler is enabled only with GPU")
def test_all_profiler(self):
self.net_profiler('All', '/tmp/profile_out')
with open('/tmp/profile_out', 'r') as f:
with open('/tmp/profile_out', 'rb') as f:
self.assertGreater(len(f.read()), 0)
......
......@@ -14,6 +14,7 @@
import unittest
import paddle.fluid.core as core
import paddle.compat as cpt
from paddle.fluid.framework import Program
......@@ -108,7 +109,7 @@ class TestVarDesc(unittest.TestCase):
def test_shape(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_var')
var = block.var(cpt.to_bytes('my_var'))
var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
src_shape = [3, 2, 10, 8]
var.set_shape(src_shape)
......@@ -119,7 +120,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_shape(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_reader')
var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER)
src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
var.set_shapes(src_shapes)
......@@ -130,7 +131,7 @@ class TestVarDesc(unittest.TestCase):
def test_dtype(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_var')
var = block.var(cpt.to_bytes('my_var'))
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_dtype(core.VarDesc.VarType.INT32)
self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
......@@ -139,7 +140,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_dtype(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_reader')
var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER)
src_types = [
core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
......@@ -152,7 +153,7 @@ class TestVarDesc(unittest.TestCase):
def test_multiple_lod_level(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_reader')
var = block.var(cpt.to_bytes('my_reader'))
var.set_type(core.VarDesc.VarType.READER)
src_types = [3, 1, 2]
var.set_lod_levels(src_types)
......@@ -166,12 +167,12 @@ class TestBlockDesc(unittest.TestCase):
self.assertIsNotNone(program_desc)
block = program_desc.block(0)
self.assertIsNotNone(block)
var1 = block.var("var1")
var2 = block.var("var2")
var3 = block.var("var3")
var1 = block.var(cpt.to_bytes("var1"))
var2 = block.var(cpt.to_bytes("var2"))
var3 = block.var(cpt.to_bytes("var3"))
all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var2, var3})
var2_re = block.find_var("var2")
var2_re = block.find_var(cpt.to_bytes("var2"))
self.assertEqual(var2_re, var2)
def test_add_op(self):
......
......@@ -21,11 +21,12 @@ from op_test import OpTest
class TestRandomCropOp(OpTest):
def setUp(self):
to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] *
5).astype("float32")
5).astype(np.int32)
self.possible_res = [
np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]),
np.array([[5, 6, 7], [9, 10, 11]]),
np.array([[6, 7, 8], [10, 11, 12]])
np.array([[1, 2, 3], [5, 6, 7]]).astype(np.int32),
np.array([[2, 3, 4], [6, 7, 8]]).astype(np.int32),
np.array([[5, 6, 7], [9, 10, 11]]).astype(np.int32),
np.array([[6, 7, 8], [10, 11, 12]]).astype(np.int32)
]
self.op_type = "random_crop"
self.inputs = {'X': to_crop, 'Seed': np.array([10])}
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle
import numpy as np
import unittest
......
......@@ -15,8 +15,8 @@
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
import paddle
import paddle.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase):
......
......@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.layers.control_flow import lod_rank_table
import numpy
import functools
class TestReorderLoDTensor(unittest.TestCase):
......@@ -101,7 +102,8 @@ class TestReorderLoDTensor(unittest.TestCase):
rank_table = [] # list of (index, length)
for i in range(len(ref_lod)):
rank_table.append((i, ref_lod[i]))
rank_table = sorted(rank_table, lambda x, y: y[1] - x[1])
rank_table = sorted(
rank_table, key=functools.cmp_to_key(lambda x, y: y[1] - x[1]))
# compute the input sequence info according to input_lod
input_value, input_lod = self.data[self.data_desc[0][0]]
......
......@@ -16,6 +16,7 @@ import unittest
import numpy as np
import math
import sys
import paddle.compat as cpt
from op_test import OpTest
......@@ -59,10 +60,10 @@ class TestROIPoolOp(OpTest):
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = roi[0]
roi_start_w = int(round(roi[1] * self.spatial_scale))
roi_start_h = int(round(roi[2] * self.spatial_scale))
roi_end_w = int(round(roi[3] * self.spatial_scale))
roi_end_h = int(round(roi[4] * self.spatial_scale))
roi_start_w = int(cpt.round(roi[1] * self.spatial_scale))
roi_start_h = int(cpt.round(roi[2] * self.spatial_scale))
roi_end_w = int(cpt.round(roi[3] * self.spatial_scale))
roi_end_h = int(cpt.round(roi[4] * self.spatial_scale))
roi_height = int(max(roi_end_h - roi_start_h + 1, 1))
roi_width = int(max(roi_end_w - roi_start_w + 1, 1))
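`cpt.round` is needed because Python 3's built-in `round` uses banker's rounding, while Python 2 rounded halves away from zero and the ROI boundaries rely on the old behavior. A hedged sketch of a Python 2 style round (assumption: paddle.compat's version may differ in detail):

import math

def py2_round(x):
    return math.floor(x + 0.5) if x >= 0 else math.ceil(x - 0.5)

print(round(0.5), py2_round(0.5))   # Python 3 prints: 0 1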
......@@ -97,8 +98,8 @@ class TestROIPoolOp(OpTest):
for w in range(wstart, wend):
if x_i[c, h, w] > out_data[i, c, ph, pw]:
out_data[i, c, ph, pw] = x_i[c, h, w]
argmax_data[i, c, ph, pw] = h * \
self.width + w
argmax_data[i, c, ph,
pw] = h * self.width + w
self.outs = out_data.astype('float32')
self.argmaxes = argmax_data.astype('int64')
......@@ -110,14 +111,14 @@ class TestROIPoolOp(OpTest):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x1 = np.random.random_integers(
0, self.width / self.spatial_scale - self.pooled_width)
0, self.width // self.spatial_scale - self.pooled_width)
y1 = np.random.random_integers(
0, self.height / self.spatial_scale - self.pooled_height)
0, self.height // self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width / self.spatial_scale)
y2 = np.random.random_integers(y1 + self.pooled_height,
self.height / self.spatial_scale)
self.width // self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height // self.spatial_scale)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import six
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
......@@ -59,7 +60,7 @@ class TestSpliteIds(unittest.TestCase):
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
outs_name = ["out%d" % i for i in xrange(3)]
outs_name = ["out%d" % i for i in six.moves.xrange(3)]
outs = [
scope.var(var_name).get_selected_rows() for var_name in outs_name
]
......
......@@ -27,7 +27,7 @@ def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
for h in range(s2):
for w in range(s3):
index = indices[nidx, cidx, h, w]
hidx = (index - index % out_wsize) / out_wsize
hidx = (index - index % out_wsize) // out_wsize
widx = index % out_wsize
out[nidx, cidx, int(hidx), int(widx)] = \
input[nidx, cidx, h, w]
......@@ -41,9 +41,9 @@ class TestUnpoolOp(OpTest):
self.init_test_case()
pre_input = np.random.random(self.shape).astype("float32")
nsize, csize, hsize, wsize = pre_input.shape
hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \
hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) // \
self.strides[0] + 1
wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \
wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) // \
self.strides[1] + 1
input = np.zeros((nsize, csize, hsize_out, wsize_out))
indices = np.zeros((nsize, csize, hsize_out, wsize_out))
......@@ -62,7 +62,7 @@ class TestUnpoolOp(OpTest):
input[nidx, cidx, i, j] = x_masked.max()
arg = x_masked.argmax()
indices[nidx, cidx, i, j] = \
(r_start + arg / self.ksize[1]) * wsize + \
(r_start + arg // self.ksize[1]) * wsize + \
c_start + arg % self.ksize[1]
output = self.unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float32")
......
......@@ -132,7 +132,7 @@ class CTCForward(object):
for k in range(end - start):
j = k + start
if j & 1 == 1:
label_idx = j / 2
label_idx = j // 2
label_val = labels_a_sequence[label_idx, 0]
fv = self.log_add(forward_vars[i - 1, j],
forward_vars[i - 1, j - 1])
......
......@@ -16,6 +16,7 @@ import contextlib
import os
import errno
import shutil
import six
import time
from . import core
......@@ -618,7 +619,7 @@ def build_feed_var_list(program, feed_order):
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list = sorted(
list(feed_order.items()), key=lambda item: item[1])
six.iteritems(feed_order), key=lambda item: item[1])
feed_var_list = [
program.global_block().var(pair[0]) for pair in sorted_pair_list
]
......@@ -1036,7 +1037,7 @@ def _save_trainer_args(dirname, trainer_id, trainer_args):
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in list(trainer_args.items()):
for name, value in six.iteritems(trainer_args):
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
......
......@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import six
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block._remove_op(start) for _ in range(end - start + 1)]
[block._remove_op(start) for _ in six.moves.range(end - start + 1)]
except Exception as e:
raise e
block.program._sync_with_cpp()
......
......@@ -31,6 +31,7 @@ Steps to transpile pserver:
import math
import random
import numpy as np
import collections
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
......@@ -220,9 +221,10 @@ class DistributeTranspiler(object):
# fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = list(self.grad_var_mapping.items())
grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))
if not self.config.slice_var_up:
random.seed(self.trainer_num)
random.seed(self.origin_program.random_seed)
random.shuffle(grad_var_mapping_items)
for orig_varname, splited_vars in grad_var_mapping_items:
......@@ -280,7 +282,7 @@ class DistributeTranspiler(object):
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in list(self.param_var_mapping.items()):
for varname, splited_var in six.iteritems(self.param_var_mapping):
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
......@@ -305,7 +307,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in list(self.param_var_mapping.items()):
for varname, splited_var in six.iteritems(self.param_var_mapping):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[varname]
......@@ -641,14 +643,14 @@ class DistributeTranspiler(object):
# 1. create vars in pserver program to startup program
pserver_vars = pserver_program.global_block().vars
created_var_map = dict()
for _, var in list(pserver_vars.items()):
created_var_map = collections.OrderedDict()
for _, var in six.iteritems(pserver_vars):
tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar
# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
new_outputs = dict()
new_outputs = collections.OrderedDict()
# do not append startup op if var is not on this pserver
op_on_pserver = False
# TODO(gongwb): remove this line.
......@@ -789,7 +791,7 @@ class DistributeTranspiler(object):
self.origin_program,
grad_blocks,
add_trainer_suffix=self.trainer_num > 1)
self.grad_param_mapping = dict()
self.grad_param_mapping = collections.OrderedDict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":")
......@@ -797,7 +799,7 @@ class DistributeTranspiler(object):
self.param_var_mapping[p_name][int(p_bid)]
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
self.param_grad_ep_mapping = collections.OrderedDict()
[
self.param_grad_ep_mapping.update({
ep: {
......@@ -1072,21 +1074,21 @@ class DistributeTranspiler(object):
block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
Returns:
var_mapping (dict(varname->[new_varname_variable])):A dict mapping
var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping
from original var name to each var split.
"""
# varname->[(block_id, current_block_size)]
block_map = dict()
block_map = collections.OrderedDict()
var_mapping = dict()
var_mapping = collections.OrderedDict()
for block_str in block_list:
varname, offset, size = block_str.split(":")
if varname not in block_map:
block_map[varname] = []
block_map[varname].append((int(offset), int(size)))
for varname, splited in list(block_map.items()):
for varname, splited in six.iteritems(block_map):
orig_var = program.global_block().var(varname)
if len(splited) == 1:
if self.sync_mode and add_trainer_suffix:
......@@ -1107,7 +1109,7 @@ class DistributeTranspiler(object):
for i, block in enumerate(splited):
size = block[1]
rows = size / orig_dim1_flatten
rows = size // orig_dim1_flatten
splited_shape = [rows]
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
......@@ -1271,7 +1273,7 @@ class DistributeTranspiler(object):
grad_to_block_id, origin_program, merged_var):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
new_inputs = collections.OrderedDict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
......@@ -1357,9 +1359,7 @@ class DistributeTranspiler(object):
def _is_splited_grad_var(self, var, var_dict):
grad_block = None
# TODO(minqiyang): replace these items() with six.iteritems() to
# improve memory
for _, g in list(var_dict.items()):
for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name):
if g.name.find(".trainer_") == -1:
grad_block = g
......@@ -1369,7 +1369,7 @@ class DistributeTranspiler(object):
def _clone_lr_op(self, program, block, op):
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in list(inputs.items()):
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1378,7 +1378,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key, varlist in list(outputs.items()):
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1393,7 +1393,7 @@ class DistributeTranspiler(object):
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in list(inputs.items()):
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1412,7 +1412,7 @@ class DistributeTranspiler(object):
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in list(outputs.items()):
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
......@@ -1470,7 +1470,7 @@ class DistributeTranspiler(object):
def _get_input_map_from_op(self, varmap, op):
"""Returns a dict from op input name to the vars in varmap."""
iomap = dict()
iomap = collections.OrderedDict()
for key in op.input_names:
vars = []
for varname in op.input(key):
......@@ -1483,7 +1483,7 @@ class DistributeTranspiler(object):
def _get_output_map_from_op(self, varmap, op):
"""Returns a dict from op output name to the vars in varmap."""
iomap = dict()
iomap = collections.OrderedDict()
for key in op.output_names:
vars = []
for varname in op.output(key):
......
......@@ -14,6 +14,7 @@
from collections import defaultdict
from .. import core
from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter
from ..backward import _rename_arg_
from functools import reduce
......@@ -125,15 +126,15 @@ class ControlFlowGraph(object):
def _has_var(self, block_desc, var_name, is_forward):
if is_forward:
return block_desc.has_var(str(var_name))
return block_desc.has_var(cpt.to_bytes(var_name))
else:
return block_desc.has_var_recursive(str(var_name))
return block_desc.has_var_recursive(cpt.to_bytes(var_name))
def _find_var(self, block_desc, var_name, is_forward):
if is_forward:
return block_desc.find_var(str(var_name))
return block_desc.find_var(cpt.to_bytes(var_name))
else:
return block_desc.find_var_recursive(str(var_name))
return block_desc.find_var_recursive(cpt.to_bytes(var_name))
def _check_var_validity(self, block_desc, x, is_forward):
if str(x) == "@EMPTY@":
......@@ -258,7 +259,7 @@ class ControlFlowGraph(object):
# Rename the var to the cache var already with
# memory allocated in order to reuse the memory.
_rename_arg_(self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(str(
self._program.block(block_desc.id).var(cpt.to_text(
x)).desc = self._find_var(block_desc, cache_var,
is_forward)
self._update_graph(x, cache_var, begin_idx=i)
......
......@@ -27,6 +27,7 @@ from six.moves import zip
import itertools
import random
import zlib
import paddle.compat as cpt
def map_readers(func, *readers):
......@@ -390,9 +391,9 @@ class PipeReader:
buff = self.process.stdout.read(self.bufsize)
if buff:
if self.file_type == "gzip":
decomp_buff = self.dec.decompress(buff)
decomp_buff = cpt.to_text(self.dec.decompress(buff))
elif self.file_type == "plain":
decomp_buff = buff
decomp_buff = cpt.to_text(buff)
else:
raise TypeError("file_type %s is not allowed" %
self.file_type)
......
......@@ -29,6 +29,7 @@ import os
import unittest
import numpy as np
import paddle.reader.creator
import six
class TestNumpyArray(unittest.TestCase):
......@@ -37,7 +38,7 @@ class TestNumpyArray(unittest.TestCase):
x = np.array(l, np.int32)
reader = paddle.reader.creator.np_array(x)
for idx, e in enumerate(reader()):
self.assertItemsEqual(e, l[idx])
six.assertCountEqual(self, e, l[idx])
class TestTextFile(unittest.TestCase):
......
......@@ -37,9 +37,9 @@ if __name__ == '__main__':
assert isinstance(conf, TrainerConfig_pb2.TrainerConfig)
if whole_conf:
print conf
print(conf)
else:
if binary:
sys.stdout.write(conf.model_config.SerializeToString())
else:
print conf.model_config
print(conf.model_config)
......@@ -15,7 +15,8 @@
import os, sys
import numpy as np
from PIL import Image
from cStringIO import StringIO
import six
from six.moves import cStringIO as StringIO
import multiprocessing
import functools
import itertools
......@@ -187,7 +188,8 @@ class PILTransformer(ImageTransformer):
return self.transform(im)
def job(is_img_string, transformer, (data, label)):
def job(is_img_string, transformer, data_label_pack):
(data, label) = data_label_pack
if is_img_string:
return transformer.transform_from_string(data), label
else:
......@@ -208,7 +210,7 @@ class MultiProcessImageTransformer(object):
"""
Processes images with multiple processes. If it is used in PyDataProvider,
the simple usage for CNN is as follows:
.. code-block:: python
def hook(settings, is_train, **kwargs):
......@@ -229,7 +231,7 @@ class MultiProcessImageTransformer(object):
@provider(init_hook=hook, pool_size=20480)
def process(settings, file_list):
with open(file_list, 'r') as fdata:
for line in fdata:
data_dic = np.load(line.strip()) # load the data batch pickled by Pickle.
data = data_dic['data']
labels = data_dic['label']
......@@ -249,10 +251,10 @@ class MultiProcessImageTransformer(object):
:type channel_swap: tuple or list
:param mean: the mean values of image, per-channel mean or element-wise mean.
:type mean: array, The dimension is 1 for per-channel mean.
The dimension is 3 for element-wise mean.
:param is_train: training period or testing period.
:type is_train: bool.
:param is_color: whether the image is color or grayscale.
:type is_color: bool.
:param is_img_string: The input can be the file name of image or image string.
:type is_img_string: bool.
......@@ -273,4 +275,4 @@ class MultiProcessImageTransformer(object):
def run(self, data, label):
fun = functools.partial(job, self.is_img_string, self.transformer)
return self.pool.imap_unordered(
fun, itertools.izip(data, label), chunksize=100 * self.procnum)
fun, six.moves.zip(data, label), chunksize=100 * self.procnum)
......@@ -14,7 +14,7 @@
import numpy as np
from PIL import Image
from cStringIO import StringIO
from six.moves import cStringIO as StringIO
def resize_image(img, target_size):
......@@ -34,7 +34,7 @@ def flip(im):
"""
Return the flipped image.
Flip an image along the horizontal direction.
im: input image, (H x W x K) ndarrays
"""
if len(im.shape) == 3:
return im[:, :, ::-1]
......@@ -132,7 +132,7 @@ def load_meta(meta_path, mean_img_size, crop_size, color=True):
def load_image(img_path, is_color=True):
"""
Load image and return.
img_path: image path.
is_color: is color image or not.
"""
......@@ -205,7 +205,7 @@ class ImageTransformer:
def set_mean(self, mean):
if mean is not None:
# mean value, may be one value per channel
if mean.ndim == 1:
mean = mean[:, np.newaxis, np.newaxis]
else:
......
......@@ -15,6 +15,9 @@
# Generate dot diagram file for the given paddle model config
# The generated file can be viewed using Graphviz (http://graphviz.org)
from __future__ import print_function
import six
import sys
import traceback
......@@ -61,9 +64,9 @@ def make_diagram_from_proto(model_config, dot_file):
name2id[mem.link_name])
return s
print >> f, 'digraph graphname {'
print >> f, 'node [width=0.375,height=0.25];'
for i in xrange(len(model_config.layers)):
print('digraph graphname {', file=f)
print('node [width=0.375,height=0.25];', file=f)
for i in six.moves.xrange(len(model_config.layers)):
l = model_config.layers[i]
name2id[l.name] = i
......@@ -71,12 +74,12 @@ def make_diagram_from_proto(model_config, dot_file):
for sub_model in model_config.sub_models:
if sub_model.name == 'root':
continue
print >> f, 'subgraph cluster_%s {' % i
print >> f, 'style=dashed;'
print('subgraph cluster_%s {' % i, file=f)
print('style=dashed;', file=f)
label = '%s ' % sub_model.name
if sub_model.reversed:
label += '<=='
print >> f, 'label = "%s";' % label
print('label = "%s";' % label, file=f)
i += 1
submodel_layers.add(sub_model.name)
for layer_name in sub_model.layer_names:
......@@ -84,37 +87,41 @@ def make_diagram_from_proto(model_config, dot_file):
lid = name2id[layer_name]
layer_config = model_config.layers[lid]
label = make_layer_label(layer_config)
print >> f, 'l%s [label="%s", shape=box];' % (lid, label)
print >> f, '}'
print('l%s [label="%s", shape=box];' % (lid, label), file=f)
print('}', file=f)
for i in xrange(len(model_config.layers)):
for i in six.moves.xrange(len(model_config.layers)):
l = model_config.layers[i]
if l.name not in submodel_layers:
label = make_layer_label(l)
print >> f, 'l%s [label="%s", shape=box];' % (i, label)
print('l%s [label="%s", shape=box];' % (i, label), file=f)
for sub_model in model_config.sub_models:
if sub_model.name == 'root':
continue
for link in sub_model.in_links:
print >> f, make_link(link)
print(make_link(link), file=f)
for link in sub_model.out_links:
print >> f, make_link(link)
print(make_link(link), file=f)
for mem in sub_model.memories:
print >> f, make_mem(mem)
print(make_mem(mem), file=f)
for i in xrange(len(model_config.layers)):
for i in six.moves.xrange(len(model_config.layers)):
for l in model_config.layers[i].inputs:
print >> f, 'l%s -> l%s [label="%s"];' % (
name2id[l.input_layer_name], i, l.input_parameter_name)
print(
'l%s -> l%s [label="%s"];' % (name2id[l.input_layer_name], i,
l.input_parameter_name),
file=f)
print >> f, '}'
print('}', file=f)
f.close()
def usage():
print >> sys.stderr, ("Usage: python show_model_diagram.py" +
" CONFIG_FILE DOT_FILE [config_str]")
print(
("Usage: python show_model_diagram.py" +
" CONFIG_FILE DOT_FILE [config_str]"),
file=sys.stderr)
exit(1)
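The conversion pattern applied throughout this file: `from __future__ import print_function` makes `print(..., file=f)` valid on Python 2 as well, and `six.moves.xrange` selects `xrange` on py2 and `range` on py3, both lazy. A condensed sketch (output path hypothetical):

```python
from __future__ import print_function

import six

with open('graph.dot', 'w') as f:        # hypothetical output path
    print('digraph graphname {', file=f)
    for i in six.moves.xrange(3):        # xrange on py2, range on py3
        print('l%s [shape=box];' % i, file=f)
    print('}', file=f)
```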
......
......@@ -70,4 +70,4 @@ def merge_v2_model(net, param_file, output_file):
for pname in param_names:
params.serialize(pname, f)
print 'Generate %s success!' % (output_file)
print('Generate %s success!' % (output_file))
......@@ -44,6 +44,7 @@ To use this script to generate plot for AvgCost, error:
python plotcurve.py -i paddle.INFO -o figure.png AvgCost error
"""
import six
import sys
import matplotlib
# the following line is added immediately after import matplotlib
......@@ -91,7 +92,7 @@ def plot_paddle_curve(keys, inputfile, outputfile, format='png',
sys.stderr.write("No data to plot. Exiting!\n")
return
m = len(keys) + 1
for i in xrange(1, m):
for i in six.moves.xrange(1, m):
pyplot.plot(
x[:, 0],
x[:, i],
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import six
import os
from paddle.trainer.config_parser import *
from paddle.utils.preprocess_img import \
......@@ -112,7 +113,7 @@ def simple_conv_net(data_conf, is_color=False):
num_classes: num of classes.
is_color: whether the input images are color.
"""
for k, v in data_conf.iteritems():
for k, v in six.iteritems(data_conf):
globals()[k] = v
data_input, label_input, num_image_channels = \
image_data_layers(image_size, num_classes, is_color, is_predict)
......@@ -340,7 +341,7 @@ def small_vgg(data_conf, is_predict=False):
num_classes: num of classes.
is_color: whether the input images are color.
"""
for k, v in data_conf.iteritems():
for k, v in six.iteritems(data_conf):
globals()[k] = v
vgg_conv_net(image_size, num_classes,
num_layers=[2, 2, 3, 3],
......
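`dict.iteritems()` no longer exists on Python 3; `six.iteritems(d)` dispatches to `iteritems()` on py2 and `items()` on py3, yielding lazily either way. A minimal sketch of the injection pattern above (config contents hypothetical):

```python
import six

data_conf = {'image_size': 32, 'num_classes': 10}   # hypothetical config

for k, v in six.iteritems(data_conf):   # iteritems() on py2, items() on py3
    globals()[k] = v                    # same injection trick as above

assert image_size == 32 and num_classes == 10
```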
......@@ -17,9 +17,9 @@ import os
import random
import numpy as np
import PIL.Image as Image
import StringIO
import preprocess_util
from image_util import crop_img
from six.moves import cStringIO as StringIO
from . import preprocess_util
from .image_util import crop_img
def resize_image(img, target_size):
......@@ -52,7 +52,7 @@ class DiskImage:
def read_image(self):
if self.img is None:
print "reading: " + self.path
print("reading: " + self.path)
image = resize_image(Image.open(self.path), self.target_size)
self.img = image
......@@ -69,7 +69,7 @@ class DiskImage:
convert the image into the paddle batch format.
"""
self.read_image()
output = StringIO.StringIO()
output = StringIO()
self.img.save(output, "jpeg")
contents = output.getvalue()
return contents
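One caveat on `to_paddle_format`: `Image.save` emits raw bytes, and on Python 3 `six.moves.cStringIO` is the text-only `io.StringIO`, so an in-memory JPEG buffer would need `io.BytesIO` there. A hedged sketch of a bytes-safe variant (image contents hypothetical):

```python
import io

from PIL import Image

img = Image.new('RGB', (8, 8))     # hypothetical 8x8 image
output = io.BytesIO()              # bytes buffer, safe on py2 and py3
img.save(output, 'jpeg')
contents = output.getvalue()       # raw JPEG bytes
```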
......@@ -127,7 +127,7 @@ class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
image_path = items[0]
label_name = items[1]
if not label_name in label_set:
label_set[label_name] = len(label_set.keys())
label_set[label_name] = len(list(label_set.keys()))
img = DiskImage(path=image_path, target_size=self.target_size)
label = preprocess_util.Lablel(
label=label_set[label_name], name=label_name)
......@@ -144,7 +144,7 @@ class ImageClassificationDatasetCreater(preprocess_util.DatasetCreater):
return create_dataset_from_list(path)
label_set = preprocess_util.get_label_set_from_dir(path)
data = []
for l_name in label_set.keys():
for l_name in list(label_set.keys()):
image_paths = preprocess_util.list_images(
os.path.join(path, l_name))
for p in image_paths:
......
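The `list(label_set.keys())` wrapping reflects that `dict.keys()` returns a view rather than a list on Python 3; wrapping restores list semantics where indexing is needed. When only the size matters, `len(label_set)` is equivalent and cheaper. A quick sketch (label map hypothetical):

```python
label_set = {'cat': 0, 'dog': 1}   # hypothetical label map

names = list(label_set.keys())     # a real list on py2 and py3 alike
assert names[0] in label_set

# When only the size matters, the dict's own length is enough:
label_set['bird'] = len(label_set)
assert label_set['bird'] == 2
```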
......@@ -14,7 +14,7 @@
import os
import math
import cPickle as pickle
import six.moves.cPickle as pickle
import random
import collections
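`cPickle` was merged into `pickle` in Python 3; `six.moves.cPickle` exposes the C-accelerated module on py2 and plain `pickle` on py3 under one name. A minimal round-trip sketch (payload hypothetical):

```python
import six.moves.cPickle as pickle

blob = pickle.dumps({'data': [1, 2, 3]})   # hypothetical payload
assert pickle.loads(blob) == {'data': [1, 2, 3]}
```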
......@@ -169,7 +169,7 @@ class Dataset:
random.shuffle(keyvalue_indices[k])
num_data_per_key_batch = \
math.ceil(num_per_batch / float(len(keyvalue_indices.keys())))
math.ceil(num_per_batch / float(len(list(keyvalue_indices.keys()))))
if num_data_per_key_batch < 2:
raise Exception("The number of data in a batch is too small")
......@@ -182,8 +182,8 @@ class Dataset:
end_idx = int(
min(begin_idx + num_data_per_key_batch,
len(keyvalue_indices[k])))
print "begin_idx, end_idx"
print begin_idx, end_idx
print("begin_idx, end_idx")
print(begin_idx, end_idx)
for idx in range(begin_idx, end_idx):
permuted_data.append(self.data[keyvalue_indices[k][idx]])
keyvalue_readpointer[k] = end_idx
......@@ -357,6 +357,6 @@ class DatasetCreater(object):
data_batcher.create_batches_and_list(
self.output_path, self.train_list_name, self.test_list_name,
self.label_set_name)
self.num_classes = len(train_label_set.keys())
self.num_classes = len(list(train_label_set.keys()))
self.create_meta_file(train_data)
return out_path
......@@ -15,6 +15,8 @@
Show the content of proto buffer data file of PADDLE
"""
from __future__ import print_function
import os
import sys
from google.protobuf.internal.decoder import _DecodeVarint
......@@ -39,7 +41,7 @@ def read_proto(file, message):
def usage():
print >> sys.stderr, "Usage: python show_pb.py PROTO_DATA_FILE"
print("Usage: python show_pb.py PROTO_DATA_FILE", file=sys.stderr)
exit(1)
......@@ -50,8 +52,8 @@ if __name__ == '__main__':
f = open(sys.argv[1])
header = DataFormat.DataHeader()
read_proto(f, header)
print header
print(header)
sample = DataFormat.DataSample()
while read_proto(f, sample):
print sample
print(sample)
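For context on `read_proto`: the data file stores length-prefixed messages, i.e. a varint byte count followed by that many serialized bytes. The sketch below is one plausible reading of that loop, not the file's actual body; it assumes the file is opened in binary mode, which Python 3 requires here:

```python
from google.protobuf.internal.decoder import _DecodeVarint

def read_proto(f, message):
    # A record is a varint byte count followed by the serialized message.
    buf = b''
    while True:
        ch = f.read(1)
        if not ch:
            return None               # clean end of file
        buf += ch
        if not (ord(ch) & 0x80):      # high bit clear: varint is complete
            break
    size, _ = _DecodeVarint(buf, 0)
    message.ParseFromString(f.read(size))
    return message
```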
......@@ -24,7 +24,7 @@ import sys
import struct
import numpy as np
import torchfile
import cPickle as pickle
import six.moves.cPickle as pickle
import argparse
......@@ -48,7 +48,7 @@ def save_net_parameters(layers, params, output_path):
biases = params[i * 2 + 1]
weight_file = os.path.join(output_path, '_%s.w0' % layers[i])
biases_file = os.path.join(output_path, '_%s.wbias' % layers[i])
print "Saving for layer %s." % layers[i]
print("Saving for layer %s." % layers[i])
save_layer_parameters(weight_file, [weight])
save_layer_parameters(biases_file, biases)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import unittest
import os
import sys
......