未验证 提交 1b86bb5a 编写于 作者: T Teng Xi 提交者: GitHub

Cherry pick 1.1.1 (#332)

* fix conflict

* fix windows CPU envs (#315)

* change scipy.misc.imread() to imageio.imread() (#324)

* adapt PY2 PY3 (#326)
上级 fe947e3d
...@@ -13,7 +13,11 @@ ...@@ -13,7 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import scipy.misc import six
if six.PY2:
import scipy.misc as imgreader
else:
import imageio as imgreader
import os import os
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -54,7 +58,7 @@ class CASIA_Face(object): ...@@ -54,7 +58,7 @@ class CASIA_Face(object):
target = self.label_list[index] target = self.label_list[index]
try: try:
img = scipy.misc.imread(img_path) img = imgreader.imread(img_path)
except: except:
continue continue
......
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import scipy.misc import six
if six.PY2:
import scipy.misc as imgreader
else:
import imageio as imgreader
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -33,10 +36,10 @@ class LFW(object): ...@@ -33,10 +36,10 @@ class LFW(object):
return return
index = self.shuffle_idx.pop(0) index = self.shuffle_idx.pop(0)
imgl = scipy.misc.imread(self.imgl_list[index]) imgl = imgreader.imread(self.imgl_list[index])
if len(imgl.shape) == 2: if len(imgl.shape) == 2:
imgl = np.stack([imgl] * 3, 2) imgl = np.stack([imgl] * 3, 2)
imgr = scipy.misc.imread(self.imgr_list[index]) imgr = imgreader.imread(self.imgr_list[index])
if len(imgr.shape) == 2: if len(imgr.shape) == 2:
imgr = np.stack([imgr] * 3, 2) imgr = np.stack([imgr] * 3, 2)
......
...@@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args): ...@@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args):
def build_program(program, startup, args, is_train=True): def build_program(program, startup, args, is_train=True):
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) if args.use_gpu:
num_trainers = fluid.core.get_cuda_device_count()
else:
num_trainers = int(os.environ.get('CPU_NUM', 1))
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace() places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir) train_dataset = CASIA_Face(root=args.train_data_dir)
...@@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True):
image = fluid.data( image = fluid.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32') name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64') label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.batch( train_reader = fluid.io.batch(
train_dataset.reader, train_dataset.reader,
batch_size=args.train_batchsize // num_trainers, batch_size=args.train_batchsize // num_trainers,
drop_last=False) drop_last=False)
...@@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True):
else: else:
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.batch( test_reader = fluid.io.batch(
test_dataset.reader, test_dataset.reader,
batch_size=args.test_batchsize, batch_size=args.test_batchsize,
drop_last=False) drop_last=False)
...@@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True):
def quant_val_reader_batch(): def quant_val_reader_batch():
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.batch( test_reader = fluid.io.batch(
test_dataset.reader, batch_size=1, drop_last=False) test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 3) shuffle_reader = fluid.io.shuffle(test_reader, 3)
...@@ -298,14 +301,16 @@ def main(): ...@@ -298,14 +301,16 @@ def main():
help='The path of the extract features save, must be .mat file') help='The path of the extract features save, must be .mat file')
args = parser.parse_args() args = parser.parse_args()
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) if args.use_gpu:
num_trainers = fluid.core.get_cuda_device_count()
else:
num_trainers = int(os.environ.get('CPU_NUM', 1))
print(args) print(args)
print('num_trainers: {}'.format(num_trainers)) print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None: if args.save_ckpt == None:
args.save_ckpt = 'output' args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt): if not os.path.isdir(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt]) os.makedirs(args.save_ckpt)
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f: with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n') f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n') f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
...@@ -346,7 +351,7 @@ def main(): ...@@ -346,7 +351,7 @@ def main():
executor=exe) executor=exe)
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.batch( test_reader = fluid.io.batch(
test_dataset.reader, test_dataset.reader,
batch_size=args.test_batchsize, batch_size=args.test_batchsize,
drop_last=False) drop_last=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册