“94c8ffd18bd81113f0b588457e0822cfa13f278d”上不存在“third_party/nnlib/git@gitcode.net:qq_37101384/mace.git”
提交 bf2c53ae 编写于 作者: T tangwei12

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api_about_cpkt

...@@ -19,4 +19,4 @@ ADD *.whl / ...@@ -19,4 +19,4 @@ ADD *.whl /
RUN pip install /*.whl && rm -f /*.whl && chmod +x /usr/bin/paddle_k8s RUN pip install /*.whl && rm -f /*.whl && chmod +x /usr/bin/paddle_k8s
ENV LD_LIBRARY_PATH=/usr/local/lib ENV LD_LIBRARY_PATH=/usr/local/lib
ADD fluid_benchmark.py dataset.py models/ /workspace/ ADD fluid_benchmark.py recordio_converter.py models/ /workspace/
...@@ -29,9 +29,11 @@ Currently supported `--model` argument include: ...@@ -29,9 +29,11 @@ Currently supported `--model` argument include:
You can choose to use GPU/CPU training. With GPU training, you can specify You can choose to use GPU/CPU training. With GPU training, you can specify
`--gpus <gpu_num>` to run multi GPU training. `--gpus <gpu_num>` to run multi GPU training.
* Run distributed training with parameter servers: * Run distributed training with parameter servers:
* see [run_fluid_benchmark.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/fluid/run_fluid_benchmark.sh) as an example.
* start parameter servers: * start parameter servers:
```bash ```bash
PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver
sleep 15
``` ```
* start trainers: * start trainers:
```bash ```bash
...@@ -42,6 +44,16 @@ Currently supported `--model` argument include: ...@@ -42,6 +44,16 @@ Currently supported `--model` argument include:
PADDLE_PSERVER_PORT=7164 PADDLE_TRAINER_IPS=192.168.0.2,192.168.0.3 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method nccl2 PADDLE_PSERVER_PORT=7164 PADDLE_TRAINER_IPS=192.168.0.2,192.168.0.3 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method nccl2
``` ```
## Prepare the RecordIO file to Achieve Better Performance
Run the following command will generate RecordIO files like "mnist.recordio" under the path
and batch_size you choose, you can use batch_size=1 so that later reader can change the batch_size
at any time using `fluid.batch`.
```bash
python -c 'from recordio_converter import *; prepare_mnist("data", 1)'
```
## Run Distributed Benchmark on Kubernetes Cluster ## Run Distributed Benchmark on Kubernetes Cluster
You may need to build a Docker image before submitting a cluster job onto Kubernetes, or you will You may need to build a Docker image before submitting a cluster job onto Kubernetes, or you will
......
...@@ -38,10 +38,12 @@ def parse_args(): ...@@ -38,10 +38,12 @@ def parse_args():
default='resnet', default='resnet',
help='The model to run benchmark with.') help='The model to run benchmark with.')
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.') '--batch_size',
type=int,
default=32,
help='The batch size on each gpu.')
parser.add_argument( parser.add_argument(
'--learning_rate', type=float, default=0.001, help='The learning rate.') '--learning_rate', type=float, default=0.001, help='The learning rate.')
# TODO(wuyi): add "--use_fake_data" option back.
parser.add_argument( parser.add_argument(
'--skip_batch_num', '--skip_batch_num',
type=int, type=int,
...@@ -49,7 +51,10 @@ def parse_args(): ...@@ -49,7 +51,10 @@ def parse_args():
help='The first num of minibatch num to skip, for better performance test' help='The first num of minibatch num to skip, for better performance test'
) )
parser.add_argument( parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.') '--iterations',
type=int,
default=80,
help='The number of minibatches, set to -1 to run all batches.')
parser.add_argument( parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.') '--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument( parser.add_argument(
...@@ -69,6 +74,7 @@ def parse_args(): ...@@ -69,6 +74,7 @@ def parse_args():
type=int, type=int,
default=1, default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.') help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
# this option is available only for vgg and resnet.
parser.add_argument( parser.add_argument(
'--cpus', '--cpus',
type=int, type=int,
...@@ -78,7 +84,7 @@ def parse_args(): ...@@ -78,7 +84,7 @@ def parse_args():
'--data_set', '--data_set',
type=str, type=str,
default='flowers', default='flowers',
choices=['cifar10', 'flowers'], choices=['cifar10', 'flowers', 'imagenet'],
help='Optional dataset for benchmark.') help='Optional dataset for benchmark.')
parser.add_argument( parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.') '--infer_only', action='store_true', help='If set, run forward only.')
...@@ -108,6 +114,16 @@ def parse_args(): ...@@ -108,6 +114,16 @@ def parse_args():
default='local', default='local',
choices=['local', 'pserver', 'nccl2'], choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.') help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--use_reader_op',
action='store_true',
help='Whether to use reader op, and must specify the data path if set this to true.'
)
parser.add_argument(
'--data_path',
type=str,
default="",
help='Directory that contains all the training recordio files.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -210,6 +226,8 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -210,6 +226,8 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0) place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
if not args.use_reader_op:
feed_var_list = [ feed_var_list = [
var for var in train_prog.global_block().vars.itervalues() var for var in train_prog.global_block().vars.itervalues()
if var.is_data if var.is_data
...@@ -219,16 +237,38 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -219,16 +237,38 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
iters, num_samples, start_time = 0, 0, time.time() iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num): for pass_id in range(args.pass_num):
train_losses = [] train_losses = []
for batch_id, data in enumerate(train_reader()): if not args.use_reader_op:
reader_generator = train_reader()
batch_id = 0
data = None
while True:
if not args.use_reader_op:
data = next(reader_generator, None)
if data == None:
break
if iters == args.iterations:
break
if iters == args.skip_batch_num: if iters == args.skip_batch_num:
start_time = time.time() start_time = time.time()
num_samples = 0 num_samples = 0
if iters == args.iterations:
if args.use_reader_op:
try:
loss = exe.run(train_prog, fetch_list=[avg_loss])
except fluid.core.EnforceNotMet as ex:
break break
else:
loss = exe.run(train_prog, loss = exe.run(train_prog,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_loss]) fetch_list=[avg_loss])
iters += 1 iters += 1
batch_id += 1
# FIXME(wuyi): For use_reader_op, if the current
# pass is not the last, the last batch of this pass
# is also equal to args.batch_size.
if args.use_reader_op:
num_samples += args.batch_size * args.gpus
else:
num_samples += len(data) num_samples += len(data)
train_losses.append(loss) train_losses.append(loss)
print("Pass: %d, Iter: %d, Loss: %f\n" % print("Pass: %d, Iter: %d, Loss: %f\n" %
...@@ -250,10 +290,14 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -250,10 +290,14 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_acc, args, train_prog, startup_prog, nccl_id_var, batch_acc, args, train_prog, startup_prog, nccl_id_var,
num_trainers, trainer_id): num_trainers, trainer_id):
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
if not args.use_reader_op:
feed_var_list = [ feed_var_list = [
var for var in train_prog.global_block().vars.itervalues() var for var in train_prog.global_block().vars.itervalues()
if var.is_data if var.is_data
] ]
feeder = fluid.DataFeeder(feed_var_list, place)
# generate fake: # generate fake:
if args.use_fake_data: if args.use_fake_data:
for var in feed_var_list: for var in feed_var_list:
...@@ -270,7 +314,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -270,7 +314,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
"value": 1.0, "value": 1.0,
"dtype": var.dtype}) "dtype": var.dtype})
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
if nccl_id_var and trainer_id == 0: if nccl_id_var and trainer_id == 0:
#FIXME(wuyi): wait other trainer to start listening #FIXME(wuyi): wait other trainer to start listening
time.sleep(30) time.sleep(30)
...@@ -287,12 +330,21 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -287,12 +330,21 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
num_trainers=num_trainers, num_trainers=num_trainers,
trainer_id=trainer_id) trainer_id=trainer_id)
feeder = fluid.DataFeeder(feed_var_list, place)
for pass_id in range(args.pass_num): for pass_id in range(args.pass_num):
num_samples = 0 num_samples = 0
iters = 0 iters = 0
start_time = time.time() start_time = time.time()
for batch_id, data in enumerate(train_reader()): if not args.use_reader_op:
reader_generator = train_reader()
batch_id = 0
data = None
while True:
if not args.use_reader_op:
data = next(reader_generator, None)
if data == None:
break
if iters == args.iterations:
break
if args.profile and pass_id == 0 and batch_id == 5: if args.profile and pass_id == 0 and batch_id == 5:
profiler.start_profiler("All") profiler.start_profiler("All")
elif args.profile and pass_id == 0 and batch_id == 10: elif args.profile and pass_id == 0 and batch_id == 10:
...@@ -301,19 +353,25 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -301,19 +353,25 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
if iters == args.skip_batch_num: if iters == args.skip_batch_num:
start_time = time.time() start_time = time.time()
num_samples = 0 num_samples = 0
if iters == args.iterations: if args.use_fake_data or args.use_reader_op:
break try:
if args.use_fake_data:
loss, = exe.run([avg_loss.name]) loss, = exe.run([avg_loss.name])
except fluid.core.EnforceNotMet as ex:
break
else: else:
loss, = exe.run([avg_loss.name], feed=feeder.feed(data)) loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver": if args.update_method == "pserver":
exe.bcast_params() exe.bcast_params()
if args.use_reader_op:
num_samples += args.batch_size * args.gpus
else:
num_samples += len(data) num_samples += len(data)
iters += 1 iters += 1
if batch_id % 1 == 0: if batch_id % 1 == 0:
print("Pass %d, batch %d, loss %s" % print("Pass %d, batch %d, loss %s" %
(pass_id, batch_id, np.array(loss))) (pass_id, batch_id, np.array(loss)))
batch_id += 1
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc: if not args.no_test and batch_acc:
test_acc = test(startup_exe, infer_prog, test_reader, feeder, test_acc = test(startup_exe, infer_prog, test_reader, feeder,
......
...@@ -197,6 +197,8 @@ def lodtensor_to_ndarray(lod_tensor): ...@@ -197,6 +197,8 @@ def lodtensor_to_ndarray(lod_tensor):
def get_model(args): def get_model(args):
if args.use_reader_op:
raise Exception("machine_translation do not support reader op for now.")
embedding_dim = 512 embedding_dim = 512
encoder_size = 512 encoder_size = 512
decoder_size = 512 decoder_size = 512
...@@ -221,7 +223,7 @@ def get_model(args): ...@@ -221,7 +223,7 @@ def get_model(args):
train_batch_generator = paddle.batch( train_batch_generator = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000), paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=args.batch_size) batch_size=args.batch_size * args.gpus)
test_batch_generator = paddle.batch( test_batch_generator = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import argparse import argparse
import time import time
import cProfile import cProfile
import os
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -65,7 +66,22 @@ def cnn_model(data): ...@@ -65,7 +66,22 @@ def cnn_model(data):
def get_model(args): def get_model(args):
# Input data if args.use_reader_op:
filelist = [
os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
]
data_file = fluid.layers.open_files(
filenames=filelist,
shapes=[[-1, 1, 28, 28], (-1, 1)],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
thread_num=args.gpus,
pass_num=args.pass_num)
data_file = fluid.layers.double_buffer(
fluid.layers.batch(
data_file, batch_size=args.batch_size))
images, label = fluid.layers.read_file(data_file)
else:
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -103,7 +119,7 @@ def get_model(args): ...@@ -103,7 +119,7 @@ def get_model(args):
# Reader # Reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size) paddle.dataset.mnist.train(), batch_size=args.batch_size * args.gpus)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=args.batch_size) paddle.dataset.mnist.test(), batch_size=args.batch_size)
return avg_cost, inference_program, opt, train_reader, test_reader, batch_acc return avg_cost, inference_program, opt, train_reader, test_reader, batch_acc
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import functools import functools
import numpy as np import numpy as np
import time import time
import os
import cProfile, pstats, StringIO import cProfile, pstats, StringIO
...@@ -26,6 +27,7 @@ import paddle ...@@ -26,6 +27,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
from recordio_converter import imagenet_train, imagenet_test
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
...@@ -122,14 +124,46 @@ def get_model(args): ...@@ -122,14 +124,46 @@ def get_model(args):
else: else:
dshape = [32, 32, 3] dshape = [32, 32, 3]
model = resnet_cifar10 model = resnet_cifar10
else: train_reader = paddle.dataset.cifar.train10()
test_reader = paddle.dataset.cifar.test10()
elif args.data_set == "flowers":
class_dim = 102 class_dim = 102
if args.data_format == 'NCHW': if args.data_format == 'NCHW':
dshape = [3, 224, 224] dshape = [3, 224, 224]
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
model = resnet_imagenet model = resnet_imagenet
train_reader = paddle.dataset.flowers.train()
test_reader = paddle.dataset.flowers.test()
elif args.data_set == "imagenet":
class_dim = 1000
if args.data_format == 'NCHW':
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
if not args.data_path:
raise Exception(
"Must specify --data_path when training with imagenet")
train_reader = imagenet_train(args.data_path)
test_reader = imagenet_test(args.data_path)
if args.use_reader_op:
filelist = [
os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
]
data_file = fluid.layers.open_files(
filenames=filelist,
shapes=[[-1] + dshape, (-1, 1)],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
thread_num=args.gpus,
pass_num=args.pass_num)
data_file = fluid.layers.double_buffer(
fluid.layers.batch(
data_file, batch_size=args.batch_size))
input, label = fluid.layers.read_file(data_file)
else:
input = fluid.layers.data(name='data', shape=dshape, dtype='float32') input = fluid.layers.data(name='data', shape=dshape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -162,15 +196,10 @@ def get_model(args): ...@@ -162,15 +196,10 @@ def get_model(args):
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9) optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
train_reader = paddle.batch( batched_train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.cifar.train10() train_reader, buf_size=5120),
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(), batch_size=args.batch_size * args.gpus)
buf_size=5120), batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size)
batch_size=args.batch_size)
test_reader = paddle.batch( return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc
paddle.dataset.cifar.test10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
batch_size=args.batch_size)
return avg_cost, inference_program, optimizer, train_reader, test_reader, batch_acc
...@@ -44,6 +44,9 @@ def crop_sentence(reader, crop_size): ...@@ -44,6 +44,9 @@ def crop_sentence(reader, crop_size):
def get_model(args): def get_model(args):
if args.use_reader_op:
raise Exception(
"stacked_dynamic_lstm do not support reader op for now.")
lstm_size = 512 lstm_size = 512
emb_dim = 512 emb_dim = 512
crop_size = 1500 crop_size = 1500
...@@ -114,7 +117,7 @@ def get_model(args): ...@@ -114,7 +117,7 @@ def get_model(args):
train_reader = batch( train_reader = batch(
paddle.reader.shuffle( paddle.reader.shuffle(
crop_sentence(imdb.train(word_dict), crop_size), buf_size=25000), crop_sentence(imdb.train(word_dict), crop_size), buf_size=25000),
batch_size=args.batch_size) batch_size=args.batch_size * args.gpus)
test_reader = batch( test_reader = batch(
paddle.reader.shuffle( paddle.reader.shuffle(
crop_sentence(imdb.test(word_dict), crop_size), buf_size=25000), crop_sentence(imdb.test(word_dict), crop_size), buf_size=25000),
......
...@@ -22,6 +22,7 @@ import paddle.fluid as fluid ...@@ -22,6 +22,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import argparse import argparse
import functools import functools
import os
def vgg16_bn_drop(input): def vgg16_bn_drop(input):
...@@ -65,8 +66,23 @@ def get_model(args): ...@@ -65,8 +66,23 @@ def get_model(args):
else: else:
data_shape = [224, 224, 3] data_shape = [224, 224, 3]
# Input data if args.use_reader_op:
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') filelist = [
os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
]
data_file = fluid.layers.open_files(
filenames=filelist,
shapes=[[-1] + data_shape, (-1, 1)],
lod_levels=[0, 0],
dtypes=["float32", "int64"],
thread_num=args.gpus,
pass_num=args.pass_num)
data_file = fluid.layers.double_buffer(
fluid.layers.batch(
data_file, batch_size=args.batch_size))
images, label = fluid.layers.read_file(data_file)
else:
images = fluid.layers.data(name='data', shape=dshape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program # Train program
...@@ -95,7 +111,7 @@ def get_model(args): ...@@ -95,7 +111,7 @@ def get_model(args):
paddle.dataset.cifar.train10() paddle.dataset.cifar.train10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(), if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
buf_size=5120), buf_size=5120),
batch_size=args.batch_size) batch_size=args.batch_size * args.gpus)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.cifar.test10() paddle.dataset.cifar.test10()
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(), if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.dataset import mnist, cifar, flowers, image
def convert_2_recordio(py_reader, outfilepath, batch_size, shape_data,
shape_label):
num_batches = 0
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(py_reader(), batch_size=batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=shape_data),
fluid.layers.data(
name='label', shape=shape_label, dtype='int64'),
],
place=fluid.CPUPlace())
num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
outfilepath, reader, feeder)
return num_batches
def prepare_mnist(outpath, batch_size):
outfilepath = os.path.join(outpath, "mnist.recordio")
convert_2_recordio(mnist.train, outfilepath, batch_size, [784], [1])
def prepare_cifar10(outpath, batch_size):
outfilepath = os.path.join(outpath, "cifar.recordio")
convert_2_recordio(cifar.train10, outfilepath, batch_size, [3, 32, 32], [1])
def prepare_flowers(outpath, batch_size):
outfilepath = os.path.join(outpath, "flowers.recordio")
convert_2_recordio(flowers.train, outfilepath, batch_size, [3, 224, 224],
[1])
def default_mapper(sample):
img, label = sample
img = image.simple_transform(
img, 256, 224, True, mean=[103.94, 116.78, 123.68])
return img.flatten().astype('float32'), label
def imagenet_train(data_dir):
contents = os.listdir(data_dir)
if set(contents) != set(
["train", "train.txt", "val", "val_set", "val.txt", "unzip.sh"]):
raise Exception("Imagenet data contents error!")
img2label = dict()
imgfilelist = []
with open(os.path.join(data_dir, "train.txt")) as fn:
while 1:
l = fn.readline()
if not l:
break
img, lbl = l[:-1].split(" ")
img2label[img] = int(lbl)
imgfilelist.append(img)
# shuffle all, this is slow
random.shuffle(imgfilelist)
def train_reader():
for idx, imgfile in enumerate(imgfilelist):
data = image.load_image(
os.path.join(data_dir, "train", imgfile.lower()))
label = [img2label[imgfile], ]
yield [data, label]
return paddle.reader.map_readers(default_mapper, train_reader)
def imagenet_test(data_dir):
contents = os.listdir(data_dir)
if set(contents) != set(
["train", "train.txt", "val", "val_set", "val.txt", "unzip.sh"]):
raise Exception("Imagenet data contents error!")
img2label = dict()
imgfilelist = []
with open(os.path.join(data_dir, "val.txt")) as fn:
while 1:
l = fn.readline()
if not l:
break
img, lbl = l[:-1].split(" ")
img2label[img] = int(lbl)
imgfilelist.append(img)
def test_reader():
for idx, imgfile in enumerate(imgfilelist):
base_path = os.path.join(data_dir, "val", imgfile.split(".")[0])
image_path = ".".join([base_path, "jpeg"])
data = image.load_image(image_path)
label = [img2label[imgfile], ]
yield [data, label]
return paddle.reader.map_readers(default_mapper, test_reader)
# FIXME(wuyi): delete this when https://github.com/PaddlePaddle/Paddle/pull/11066 is merged
def convert_reader_to_recordio_files(
filename,
batch_per_file,
reader_creator,
feeder,
compressor=core.RecordIOWriter.Compressor.Snappy,
max_num_records=1000,
feed_order=None):
if feed_order is None:
feed_order = feeder.feed_names
f_name, f_ext = os.path.splitext(filename)
assert (f_ext == ".recordio")
lines = []
f_idx = 0
counter = 0
for idx, batch in enumerate(reader_creator()):
lines.append(batch)
if idx >= batch_per_file and idx % batch_per_file == 0:
filename = "%s-%05d%s" % (f_name, f_idx, f_ext)
with fluid.recordio_writer.create_recordio_writer(
filename, compressor, max_num_records) as writer:
for l in lines:
res = feeder.feed(l)
for each in feed_order:
writer.append_tensor(res[each])
writer.complete_append_tensor()
counter += 1
lines = []
f_idx += 1
print("written file: ", filename)
return counter
def prepare_imagenet(inpath, outpath, batch_size):
r = paddle.batch(imagenet_train(inpath), batch_size=batch_size)
feeder = fluid.DataFeeder(
feed_list=[
fluid.layers.data(
name="image", shape=[3, 224, 224]), fluid.layers.data(
name="label", shape=[1], dtype='int64')
],
place=fluid.CPUPlace())
outpath = os.path.join(outpath, "imagenet.recordio")
convert_reader_to_recordio_files(outpath, 10000, r, feeder)
#!/bin/bash
PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model resnet --device CPU --update_method pserver --iterations=10000 &
sleep 15
CUDA_VISIBLE_DEVICES=0,1 PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model resnet --device GPU --update_method pserver --iterations=10000 --gpus 2 &
CUDA_VISIBLE_DEVICES=2,3 PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=1 python fluid_benchmark.py --model resnet --device GPU --update_method pserver --iterations=10000 --gpus 2 &
...@@ -92,6 +92,9 @@ if(WITH_GPU) ...@@ -92,6 +92,9 @@ if(WITH_GPU)
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile") message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
endif() endif()
if(${TENSORRT_MAJOR_VERSION} VERSION_LESS 4)
message(FATAL_ERROR "Paddle needs TensorRT >= 4.0 to compile")
endif()
include_directories(${TENSORRT_INCLUDE_DIR}) include_directories(${TENSORRT_INCLUDE_DIR})
endif() endif()
elseif(WITH_AMD_GPU) elseif(WITH_AMD_GPU)
......
...@@ -106,7 +106,7 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安 ...@@ -106,7 +106,7 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安
- 学习 Docker 有多难? - 学习 Docker 有多难?
理解 Docker 并不难,大概花十分钟看一下 `这篇文章 <https://zhuanlan.zhihu.com/p/19902938>`_ 。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。 理解 Docker 并不难,大概花十分钟看一下 `如何使用Docker <https://zhuanlan.zhihu.com/p/19902938>`_ 。这可以帮您省掉花一小时安装和配置各种开发工具,以及切换机器时需要新安装的辛苦。别忘了 PaddlePaddle 更新可能导致需要新的开发工具。更别提简化问题复现带来的好处了。
- 我可以用 IDE 吗? - 我可以用 IDE 吗?
...@@ -123,7 +123,7 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安 ...@@ -123,7 +123,7 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安
- 可以并行编译吗? - 可以并行编译吗?
是的。我们的 Docker image 运行一个 `Bash脚本 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh>`_ 。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。 是的。我们的 Docker image 运行一个 `Paddle编译Bash脚本 <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build.sh>`_ 。这个脚本调用 `make -j$(nproc)` 来启动和 CPU 核一样多的进程来并行编译。
- Docker 需要 sudo - Docker 需要 sudo
...@@ -131,11 +131,11 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安 ...@@ -131,11 +131,11 @@ PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安
- 在 Windows/MacOS 上编译很慢 - 在 Windows/MacOS 上编译很慢
Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考 `这个issue <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 。 Docker 在 Windows 和 MacOS 都可以运行。不过实际上是运行在一个 Linux 虚拟机上。可能需要注意给这个虚拟机多分配一些 CPU 和内存,以保证编译高效。具体做法请参考 `如何为Windows/Mac计算机上的Docker增加内存和虚拟机 <https://github.com/PaddlePaddle/Paddle/issues/627>`_ 。
- 磁盘不够 - 磁盘不够
本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images,也会占用磁盘。可以参考 `这篇文章 <https://zaiste.net/posts/removing_docker_containers/>`_ 来清理这些内容。 本文中的例子里,`docker run` 命令里都用了 `--rm` 参数,这样保证运行结束之后的 containers 不会保留在磁盘上。可以用 `docker ps -a` 命令看到停止后但是没有删除的 containers。`docker build` 命令有时候会产生一些中间结果,是没有名字的 images,也会占用磁盘。可以参考 `如何删除Docker Container <https://zaiste.net/posts/removing_docker_containers/>`_ 来清理这些内容。
.. _compile_deps: .. _compile_deps:
...@@ -195,7 +195,7 @@ BLAS ...@@ -195,7 +195,7 @@ BLAS
PaddlePaddle支持 `MKL <https://software.intel.com/en-us/intel-mkl>`_ 和 PaddlePaddle支持 `MKL <https://software.intel.com/en-us/intel-mkl>`_ 和
`OpenBlAS <http://www.openblas.net/>`_ 两种BLAS库。默认使用MKL。如果使用MKL并且机器含有AVX2指令集, `OpenBlAS <http://www.openblas.net/>`_ 两种BLAS库。默认使用MKL。如果使用MKL并且机器含有AVX2指令集,
还会下载MKL-DNN数学库,详细参考 `这里 <https://github.com/PaddlePaddle/Paddle/tree/develop/doc/design/mkldnn#cmake>`_ 。 还会下载MKL-DNN数学库,详细参考 `mkldnn设计文档 <https://github.com/PaddlePaddle/Paddle/tree/develop/doc/design/mkldnn#cmake>`_ 。
如果关闭MKL,则会使用OpenBLAS作为BLAS库。 如果关闭MKL,则会使用OpenBLAS作为BLAS库。
......
...@@ -24,31 +24,37 @@ set(ANAKIN_LIBRARY "" CACHE STRING "path of Anakin library") ...@@ -24,31 +24,37 @@ set(ANAKIN_LIBRARY "" CACHE STRING "path of Anakin library")
set(inference_deps paddle_inference_api paddle_fluid_api) set(inference_deps paddle_inference_api paddle_fluid_api)
# if anakin is set enable anakin api implementation # if anakin is set enable anakin api implementation
if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY) if(ANAKIN_INCLUDE AND ANAKIN_LIBRARY)
set(ANAKIN_FOUND ON) set(ANAKIN_FOUND ON)
else() else()
set(ANAKIN_FOUND OFF) set(ANAKIN_FOUND OFF)
endif() endif()
function(fetch_include_recursively root_dir)
if (IS_DIRECTORY ${root_dir})
include_directories(${root_dir})
endif()
file(GLOB ALL_SUB RELATIVE ${root_dir} ${root_dir}/*)
foreach(sub ${ALL_SUB})
if (IS_DIRECTORY ${root_dir}/${sub})
fetch_include_recursively(${root_dir}/${sub})
endif()
endforeach()
endfunction()
if (ANAKIN_FOUND) if (ANAKIN_FOUND)
# Anakin's code style doesn't follow google c style. # Anakin's code style doesn't follow google c style.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=comment set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=unused-variable -Wno-error=format-extra-args -Wno-error=comment -Wno-error=format -Wno-error=switch -Wno-error=return-type -Wno-error=non-virtual-dtor -Wno-reorder -Wno-error=cpp")
-Wno-error=reorder
-Wno-error=format
-Wno-error=switch
-Wno-error=return-type
-Wno-error=non-virtual-dtor
-Wno-error=cpp")
message(STATUS "Anakin for inference is enabled") message(STATUS "Anakin for inference is enabled")
message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}") message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}")
include_directories("${ANAKIN_INCLUDE}") fetch_include_recursively(${ANAKIN_INCLUDE})
# Anakin's source path is a mass, need to set sub-directories trivially.
include_directories("${ANAKIN_INCLUDE}/saber") link_directories(${ANAKIN_LIBRARY})
link_directories("${ANAKIN_LIBRARY}")
nv_library(inference_anakin_api SRCS paddle_inference_api_anakin_engine.cc) nv_library(inference_anakin_api SHARED SRCS paddle_inference_api.cc paddle_inference_api_anakin_engine.cc)
target_link_libraries(inference_anakin_api anakin) target_link_libraries(inference_anakin_api anakin anakin_saber_common)
list(APPEND inference_deps inference_anakin_api) list(APPEND inference_deps inference_anakin_api)
endif() endif()
...@@ -84,8 +90,8 @@ inference_api_test(test_paddle_inference_api_impl ...@@ -84,8 +90,8 @@ inference_api_test(test_paddle_inference_api_impl
ARGS test_word2vec test_image_classification) ARGS test_word2vec test_image_classification)
if (ANAKIN_FOUND) if (ANAKIN_FOUND)
nv_test(inference_anakin_test SRCS paddle_inference_api_anakin_engine_tester.cc cc_test(inference_anakin_test SRCS paddle_inference_api_anakin_engine_tester.cc
DEPS ${inference_deps} protobuf) DEPS ${inference_deps})
endif() endif()
if(WITH_TESTING) if(WITH_TESTING)
......
...@@ -19,8 +19,8 @@ limitations under the License. */ ...@@ -19,8 +19,8 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <thread>
#include "paddle/contrib/inference/paddle_inference_api.h" #include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle { namespace paddle {
namespace demo { namespace demo {
...@@ -61,13 +61,67 @@ void Main(bool use_gpu) { ...@@ -61,13 +61,67 @@ void Main(bool use_gpu) {
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i]; LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i];
} }
// TODO(Superjomn): this is should be free automatically
free(outputs[0].data.data);
}
}
void MainThreads(int num_threads, bool use_gpu) {
// Multi-threads only support on CPU
// 0. Create PaddlePredictor with a config.
NativeConfig config;
config.model_dir = FLAGS_dirname + "word2vec.inference.model";
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
auto main_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
std::vector<std::thread> threads;
for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() {
// 1. clone a predictor which shares the same parameters
auto predictor = main_predictor->Clone();
constexpr int num_batches = 3;
for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
// 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}),
.data = buf,
.dtype = PaddleDType::INT64};
std::vector<PaddleTensor> inputs(4, tensor);
std::vector<PaddleTensor> outputs;
// 3. Run
CHECK(predictor->Run(inputs, &outputs));
// 4. Get output.
ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "TID: " << tid << ", "
<< "output buffer size: " << outputs.front().data.length;
const size_t num_elements = outputs.front().data.length / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data)[i];
}
free(outputs[0].data.data);
}
});
}
for (int i = 0; i < num_threads; ++i) {
threads[i].join();
} }
} }
TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); } TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); }
TEST(demo_multi_threads, word2vec_cpu_1) { MainThreads(1, false /*use_gpu*/); }
TEST(demo_multi_threads, word2vec_cpu_4) { MainThreads(4, false /*use_gpu*/); }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); } TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); }
TEST(demo_multi_threads, word2vec_gpu_1) { MainThreads(1, true /*use_gpu*/); }
TEST(demo_multi_threads, word2vec_gpu_4) { MainThreads(4, true /*use_gpu*/); }
#endif #endif
} // namespace demo } // namespace demo
......
...@@ -113,5 +113,4 @@ struct AnakinConfig : public PaddlePredictor::Config { ...@@ -113,5 +113,4 @@ struct AnakinConfig : public PaddlePredictor::Config {
// Similarly, each engine kind should map to a unique predictor implementation. // Similarly, each engine kind should map to a unique predictor implementation.
template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative> template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config); std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
} // namespace paddle } // namespace paddle
...@@ -24,8 +24,16 @@ PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor( ...@@ -24,8 +24,16 @@ PaddleInferenceAnakinPredictor::PaddleInferenceAnakinPredictor(
} }
bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) { bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) {
// TODO(Superjomn) Tell anakin to support return code. if (!(graph_.load(config.model_file))) {
engine_.Build(config.model_file, config.max_batch_size); return false;
}
graph_.ResetBatchSize("input_0", config.max_batch_size);
// optimization for graph
if (!(graph_.Optimize())) {
return false;
}
// construct executer
executor_.init(graph_);
return true; return true;
} }
...@@ -38,24 +46,30 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -38,24 +46,30 @@ bool PaddleInferenceAnakinPredictor::Run(
<< "'s type is not float"; << "'s type is not float";
return false; return false;
} }
engine_.SetInputFromCPU( auto d_tensor_in_p = executor_.get_in(input.name);
input.name, static_cast<float *>(input.data.data), input.data.length); float *d_data_p = d_tensor_in_p->mutable_data();
if (cudaMemcpy(d_data_p,
static_cast<float *>(input.data.data),
d_tensor_in_p->valid_size() * sizeof(float),
cudaMemcpyHostToDevice) != 0) {
LOG(ERROR) << "copy data from CPU to GPU error";
return false;
}
} }
// TODO(Superjomn) Tell anakin to support return code. executor_.prediction();
engine_.Execute();
if (output_data->empty()) { if (output_data->empty()) {
LOG(ERROR) << "At least one output should be set with tensors' names."; LOG(ERROR) << "At least one output should be set with tensors' names.";
return false; return false;
} }
for (auto &output : *output_data) { for (auto &output : *output_data) {
auto *tensor = engine_.GetOutputInGPU(output.name); auto *tensor = executor_.get_out(output.name);
output.shape = tensor->shape(); output.shape = tensor->shape();
// Copy data from GPU -> CPU // Copy data from GPU -> CPU
if (cudaMemcpy(output.data.data, if (cudaMemcpy(output.data.data,
tensor->data(), tensor->mutable_data(),
tensor->size(), tensor->valid_size() * sizeof(float),
cudaMemcpyDeviceToHost) != 0) { cudaMemcpyDeviceToHost) != 0) {
LOG(ERROR) << "copy data from GPU to CPU error"; LOG(ERROR) << "copy data from GPU to CPU error";
return false; return false;
...@@ -64,9 +78,26 @@ bool PaddleInferenceAnakinPredictor::Run( ...@@ -64,9 +78,26 @@ bool PaddleInferenceAnakinPredictor::Run(
return true; return true;
} }
// TODO(Superjomn) To implement latter. anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
&PaddleInferenceAnakinPredictor::get_executer() {
return executor_;
}
// the cloned new Predictor of anakin share the same net weights from original
// Predictor
std::unique_ptr<PaddlePredictor> PaddleInferenceAnakinPredictor::Clone() { std::unique_ptr<PaddlePredictor> PaddleInferenceAnakinPredictor::Clone() {
VLOG(3) << "Anakin Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new PaddleInferenceAnakinPredictor());
// construct executer from other graph
auto anakin_predictor_p =
dynamic_cast<PaddleInferenceAnakinPredictor *>(cls.get());
if (!anakin_predictor_p) {
LOG(ERROR) << "fail to call Init";
return nullptr; return nullptr;
}
anakin_predictor_p->get_executer().init(graph_);
return std::move(cls);
} }
// A factory to help create difference predictor. // A factory to help create difference predictor.
...@@ -74,6 +105,7 @@ template <> ...@@ -74,6 +105,7 @@ template <>
std::unique_ptr<PaddlePredictor> std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>( CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(
const AnakinConfig &config) { const AnakinConfig &config) {
VLOG(3) << "Anakin Predictor create.";
std::unique_ptr<PaddlePredictor> x( std::unique_ptr<PaddlePredictor> x(
new PaddleInferenceAnakinPredictor(config)); new PaddleInferenceAnakinPredictor(config));
return x; return x;
......
...@@ -20,32 +20,42 @@ limitations under the License. */ ...@@ -20,32 +20,42 @@ limitations under the License. */
#pragma once #pragma once
// NOTE This header file do not have namespace. // NOTE This header file do not have namespace.
// TODO(Superjomn) Tell Anakin to provide better APIs. //#include <test/framework/net/paddle_api.h>
#include <test/framework/net/paddle_api.h>
#include "paddle/contrib/inference/paddle_inference_api.h" #include "paddle/contrib/inference/paddle_inference_api.h"
#include "framework/core/net/net.h"
#include "saber/saber_types.h"
namespace paddle { namespace paddle {
class PaddleInferenceAnakinPredictor : public PaddlePredictor { class PaddleInferenceAnakinPredictor : public PaddlePredictor {
public: public:
PaddleInferenceAnakinPredictor() {}
PaddleInferenceAnakinPredictor(const AnakinConfig& config); PaddleInferenceAnakinPredictor(const AnakinConfig& config);
// NOTE Unlike the native engine, the buffers of anakin engine's output_data // NOTE Unlike the native engine, the buffers of anakin engine's output_data
// should be allocated first. // should be allocated first.
// TODO(Superjomn) should unify all the behaviors of output_data accross all
// the engines.
bool Run(const std::vector<PaddleTensor>& inputs, bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data) override; std::vector<PaddleTensor>* output_data) override;
std::unique_ptr<PaddlePredictor> Clone() override; std::unique_ptr<PaddlePredictor> Clone() override;
anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>&
get_executer();
~PaddleInferenceAnakinPredictor() override{};
private: private:
bool Init(const AnakinConfig& config); bool Init(const AnakinConfig& config);
anakin::AnakinEngine<anakin::NV, anakin::graph::Graph<anakin::NV,
anakin::saber::AK_FLOAT, anakin::saber::AK_FLOAT,
anakin::Precision::FP32> anakin::Precision::FP32>
engine_; graph_;
anakin::Net<anakin::NV, anakin::saber::AK_FLOAT, anakin::Precision::FP32>
executor_;
AnakinConfig config_;
}; };
} // namespace paddle } // namespace paddle
...@@ -12,16 +12,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,16 +12,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/contrib/inference/paddle_inference_api.h" #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "gflags/gflags.h"
#include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle { namespace paddle {
TEST(inference, anakin) { AnakinConfig GetConfig() {
AnakinConfig config; AnakinConfig config;
config.model_file = "./mobilenet_v2.anakin.bin";
config.device = 0;
config.max_batch_size = 1;
return config;
}
auto engine = TEST(inference, anakin) {
AnakinConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(config); CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(config);
float data[1 * 3 * 224 * 224] = {1.0f};
PaddleBuf buf{.data = data, .length = sizeof(data)};
PaddleTensor tensor{.name = "input_0",
.shape = std::vector<int>({1, 3, 224, 224}),
.data = buf,
.dtype = PaddleDType::FLOAT32};
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
float data_out[1000];
PaddleBuf buf_out{.data = data_out, .length = sizeof(data)};
PaddleTensor tensor_out{.name = "prob_out",
.shape = std::vector<int>({1000, 1}),
.data = buf_out,
.dtype = PaddleDType::FLOAT32};
std::vector<PaddleTensor> outputs(1, tensor_out);
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
float* data_o = static_cast<float*>(outputs[0].data.data);
for (size_t j = 0; j < 1000; ++j) {
LOG(INFO) << "output[" << j << "]: " << data_o[j];
}
} }
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thread>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/contrib/inference/paddle_inference_api_impl.h" #include "paddle/contrib/inference/paddle_inference_api_impl.h"
#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/inference/tests/test_helper.h"
...@@ -45,14 +47,19 @@ NativeConfig GetConfig() { ...@@ -45,14 +47,19 @@ NativeConfig GetConfig() {
config.model_dir = FLAGS_dirname + "word2vec.inference.model"; config.model_dir = FLAGS_dirname + "word2vec.inference.model";
LOG(INFO) << "dirname " << config.model_dir; LOG(INFO) << "dirname " << config.model_dir;
config.fraction_of_gpu_memory = 0.15; config.fraction_of_gpu_memory = 0.15;
#ifdef PADDLE_WITH_CUDA
config.use_gpu = true; config.use_gpu = true;
#else
config.use_gpu = false;
#endif
config.device = 0; config.device = 0;
return config; return config;
} }
TEST(paddle_inference_api_impl, word2vec) { void MainWord2Vec(bool use_gpu) {
NativeConfig config = GetConfig(); NativeConfig config = GetConfig();
auto predictor = CreatePaddlePredictor<NativeConfig>(config); auto predictor = CreatePaddlePredictor<NativeConfig>(config);
config.use_gpu = use_gpu;
framework::LoDTensor first_word, second_word, third_word, fourth_word; framework::LoDTensor first_word, second_word, third_word, fourth_word;
framework::LoD lod{{0, 1}}; framework::LoD lod{{0, 1}};
...@@ -100,11 +107,12 @@ TEST(paddle_inference_api_impl, word2vec) { ...@@ -100,11 +107,12 @@ TEST(paddle_inference_api_impl, word2vec) {
free(outputs[0].data.data); free(outputs[0].data.data);
} }
TEST(paddle_inference_api_impl, image_classification) { void MainImageClassification(bool use_gpu) {
int batch_size = 2; int batch_size = 2;
bool use_mkldnn = false; bool use_mkldnn = false;
bool repeat = false; bool repeat = false;
NativeConfig config = GetConfig(); NativeConfig config = GetConfig();
config.use_gpu = use_gpu;
config.model_dir = config.model_dir =
FLAGS_dirname + "image_classification_resnet.inference.model"; FLAGS_dirname + "image_classification_resnet.inference.model";
...@@ -149,4 +157,143 @@ TEST(paddle_inference_api_impl, image_classification) { ...@@ -149,4 +157,143 @@ TEST(paddle_inference_api_impl, image_classification) {
free(data); free(data);
} }
void MainThreadsWord2Vec(bool use_gpu) {
NativeConfig config = GetConfig();
config.use_gpu = use_gpu;
auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);
// prepare inputs data and reference results
constexpr int num_jobs = 3;
std::vector<std::vector<framework::LoDTensor>> jobs(num_jobs);
std::vector<std::vector<PaddleTensor>> paddle_tensor_feeds(num_jobs);
std::vector<framework::LoDTensor> refs(num_jobs);
for (size_t i = 0; i < jobs.size(); ++i) {
// each job has 4 words
jobs[i].resize(4);
for (size_t j = 0; j < 4; ++j) {
framework::LoD lod{{0, 1}};
int64_t dict_size = 2073; // The size of dictionary
SetupLoDTensor(&jobs[i][j], lod, static_cast<int64_t>(0), dict_size - 1);
paddle_tensor_feeds[i].push_back(LodTensorToPaddleTensor(&jobs[i][j]));
}
// get reference result of each job
std::vector<paddle::framework::LoDTensor*> ref_feeds;
std::vector<paddle::framework::LoDTensor*> ref_fetches(1, &refs[i]);
for (auto& word : jobs[i]) {
ref_feeds.push_back(&word);
}
TestInference<platform::CPUPlace>(config.model_dir, ref_feeds, ref_fetches);
}
// create threads and each thread run 1 job
std::vector<std::thread> threads;
for (int tid = 0; tid < num_jobs; ++tid) {
threads.emplace_back([&, tid]() {
auto predictor = main_predictor->Clone();
auto& local_inputs = paddle_tensor_feeds[tid];
std::vector<PaddleTensor> local_outputs;
ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));
// check outputs range
ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length;
float* data = static_cast<float*>(local_outputs[0].data.data);
for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0);
}
// check outputs correctness
float* ref_data = refs[tid].data<float>();
EXPECT_EQ(refs[tid].numel(), static_cast<int64_t>(len / sizeof(float)));
for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3);
}
free(data);
});
}
for (int i = 0; i < num_jobs; ++i) {
threads[i].join();
}
}
void MainThreadsImageClassification(bool use_gpu) {
constexpr int num_jobs = 4; // each job run 1 batch
constexpr int batch_size = 1;
NativeConfig config = GetConfig();
config.use_gpu = use_gpu;
config.model_dir =
FLAGS_dirname + "image_classification_resnet.inference.model";
auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);
std::vector<framework::LoDTensor> jobs(num_jobs);
std::vector<std::vector<PaddleTensor>> paddle_tensor_feeds(num_jobs);
std::vector<framework::LoDTensor> refs(num_jobs);
for (size_t i = 0; i < jobs.size(); ++i) {
// prepare inputs
std::vector<std::vector<int64_t>> feed_target_shapes =
GetFeedTargetShapes(config.model_dir, /*is_combined*/ false);
feed_target_shapes[0][0] = batch_size;
framework::DDim input_dims = framework::make_ddim(feed_target_shapes[0]);
SetupTensor<float>(&jobs[i], input_dims, 0.f, 1.f);
paddle_tensor_feeds[i].push_back(LodTensorToPaddleTensor(&jobs[i]));
// get reference result of each job
std::vector<framework::LoDTensor*> ref_feeds(1, &jobs[i]);
std::vector<framework::LoDTensor*> ref_fetches(1, &refs[i]);
TestInference<platform::CPUPlace>(config.model_dir, ref_feeds, ref_fetches);
}
// create threads and each thread run 1 job
std::vector<std::thread> threads;
for (int tid = 0; tid < num_jobs; ++tid) {
threads.emplace_back([&, tid]() {
auto predictor = main_predictor->Clone();
auto& local_inputs = paddle_tensor_feeds[tid];
std::vector<PaddleTensor> local_outputs;
ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));
// check outputs correctness
ASSERT_EQ(local_outputs.size(), 1UL);
const size_t len = local_outputs[0].data.length;
float* data = static_cast<float*>(local_outputs[0].data.data);
float* ref_data = refs[tid].data<float>();
EXPECT_EQ(refs[tid].numel(), len / sizeof(float));
for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3);
}
free(data);
});
}
for (int i = 0; i < num_jobs; ++i) {
threads[i].join();
}
}
TEST(inference_api_native, word2vec_cpu) { MainWord2Vec(false /*use_gpu*/); }
TEST(inference_api_native, word2vec_cpu_threads) {
MainThreadsWord2Vec(false /*use_gpu*/);
}
TEST(inference_api_native, image_classification_cpu) {
MainThreadsImageClassification(false /*use_gpu*/);
}
TEST(inference_api_native, image_classification_cpu_threads) {
MainThreadsImageClassification(false /*use_gpu*/);
}
#ifdef PADDLE_WITH_CUDA
TEST(inference_api_native, word2vec_gpu) { MainWord2Vec(true /*use_gpu*/); }
TEST(inference_api_native, word2vec_gpu_threads) {
MainThreadsWord2Vec(true /*use_gpu*/);
}
TEST(inference_api_native, image_classification_gpu) {
MainThreadsImageClassification(true /*use_gpu*/);
}
TEST(inference_api_native, image_classification_gpu_threads) {
MainThreadsImageClassification(true /*use_gpu*/);
}
#endif
} // namespace paddle } // namespace paddle
...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope ...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method) framework_proto glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -27,6 +27,7 @@ enum class DataLayout { ...@@ -27,6 +27,7 @@ enum class DataLayout {
kNHWC = 0, kNHWC = 0,
kNCHW = 1, kNCHW = 1,
kAnyLayout = 2, kAnyLayout = 2,
kMKLDNN = 3, // all layouts supported by MKLDNN internally
}; };
inline DataLayout StringToDataLayout(const std::string& str) { inline DataLayout StringToDataLayout(const std::string& str) {
...@@ -41,6 +42,8 @@ inline DataLayout StringToDataLayout(const std::string& str) { ...@@ -41,6 +42,8 @@ inline DataLayout StringToDataLayout(const std::string& str) {
return DataLayout::kNCHW; return DataLayout::kNCHW;
} else if (s == "ANYLAYOUT") { } else if (s == "ANYLAYOUT") {
return DataLayout::kAnyLayout; return DataLayout::kAnyLayout;
} else if (s == "MKLDNNLAYOUT") {
return DataLayout::kMKLDNN;
} else { } else {
PADDLE_THROW("Unknown storage order string: %s", s); PADDLE_THROW("Unknown storage order string: %s", s);
} }
...@@ -54,8 +57,10 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) { ...@@ -54,8 +57,10 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) {
return "NCHW"; return "NCHW";
case DataLayout::kAnyLayout: case DataLayout::kAnyLayout:
return "ANY_LAYOUT"; return "ANY_LAYOUT";
case DataLayout::kMKLDNN:
return "MKLDNNLAYOUT";
default: default:
PADDLE_THROW("unknown DataLayou %d", data_layout); PADDLE_THROW("unknown DataLayout %d", data_layout);
} }
} }
......
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -88,5 +91,85 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -88,5 +91,85 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.data_layout_);
} }
#ifdef PADDLE_WITH_MKLDNN
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
switch (type) {
case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>());
case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16:
return platform::to_void_cast(tensor.data<int16_t>());
case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>());
default:
PADDLE_THROW("wrong mkldnn type provided");
}
}
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out) {
auto in_layout = kernel_type_for_var.data_layout_;
auto out_layout = expected_kernel_type.data_layout_;
PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN");
#ifdef PADDLE_WITH_MKLDNN
PADDLE_ENFORCE(in.format() != memory::format::format_undef &&
in.format() != memory::format::any,
"Input tensor should have specified memory format");
// Set default as NCHW in case not specified
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
std::vector<int> out_tz = in_tz;
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name());
memory::data_type out_type = in_type;
memory::format in_format =
in_tz.size() == 2 ? memory::format::nc : in.format();
memory::format out_format =
out_tz.size() == 2 ? memory::format::nc : ToMKLDNNFormat(out_layout);
void* in_data = GetDataFromTensor(in, in_type);
// output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims());
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory);
out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef);
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -22,6 +23,50 @@ ...@@ -22,6 +23,50 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNFormat = mkldnn::memory::format;
using MKLDNNDataType = mkldnn::memory::data_type;
inline MKLDNNFormat ToMKLDNNFormat(const DataLayout& layout) {
switch (layout) {
case DataLayout::kNHWC:
return MKLDNNFormat::nhwc;
case DataLayout::kNCHW:
return MKLDNNFormat::nchw;
default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format",
DataLayoutToString(layout));
}
}
inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
switch (format) {
case MKLDNNFormat::nhwc:
return DataLayout::kNHWC;
case MKLDNNFormat::nchw:
return DataLayout::kNCHW;
default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout");
}
}
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
static const std::map<std::type_index, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
auto iter = dict.find(type);
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef;
}
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out);
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var, void TransDataLayout(const OpKernelType& kernel_type_for_var,
......
...@@ -33,11 +33,34 @@ void DataTransform(const OpKernelType& expected_kernel_type, ...@@ -33,11 +33,34 @@ void DataTransform(const OpKernelType& expected_kernel_type,
Tensor in; Tensor in;
in.ShareDataWith(input_tensor); in.ShareDataWith(input_tensor);
Tensor out; Tensor out;
DataLayout lin = kernel_type_for_var.data_layout_;
DataLayout lout = expected_kernel_type.data_layout_;
// do layout transform // do layout transform
if (NeedTransformLayout(expected_kernel_type.data_layout_, if (NeedTransformLayout(lout, lin)) {
kernel_type_for_var.data_layout_)) { if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
PADDLE_ENFORCE(
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN),
"No layout transform needed between two MKLDNN OPKernels");
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
#ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
out.ShareDataWith(input_tensor);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(ToMKLDNNFormat(lin));
#endif
} else {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
// Do transform via MKLDNN lib
TransDataLayoutFromMKLDNN(kernel_type_for_var, expected_kernel_type, in,
&out);
}
} else {
// Case3 - transfrom between Non-MKLDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
}
transformed = true; transformed = true;
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
......
...@@ -7,12 +7,13 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place ...@@ -7,12 +7,13 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_GPU) if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda) dynload_cuda variable_visitor)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
...@@ -24,10 +25,14 @@ else() ...@@ -24,10 +25,14 @@ else()
endif() endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
......
...@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
void RunImpl() override; void RunImpl() override;
private: private:
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -29,6 +31,8 @@ struct BuildStrategy { ...@@ -29,6 +31,8 @@ struct BuildStrategy {
ReduceStrategy reduce_{ReduceStrategy::kAllReduce}; ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
std::string debug_graphviz_path_{""};
}; };
} // namespace details } // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_vars_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
void FuseVarsOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(in_var_handles.size(), 0);
PADDLE_ENFORCE_EQ(out_var_handles.size() - 1, inputs_numel_.size(), "");
auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto out_var_handle = out_var_handles[0];
auto out_var = scope->Var(out_var_handle->name_);
auto out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_);
int64_t s = 0;
for (size_t i = 1; i < out_var_handles.size(); ++i) {
auto out_name = out_var_handles[i]->name_;
auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>();
auto numel = this->inputs_numel_.at(out_name);
out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
s += numel;
}
this->RunAndRecordEvent([] {});
}
std::string FuseVarsOpHandle::Name() const { return "fuse vars"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
: local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
total_numel_ = 0;
for (auto in_numel : inputs_numel) {
PADDLE_ENFORCE_GT(in_numel.second, 0);
total_numel_ += in_numel.second;
}
}
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_;
int64_t total_numel_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif
); // NOLINT
if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
}
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
class Scope;
namespace details {
class SSAGraphBuilderFactory {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {}
#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif
std::unique_ptr<SSAGraphBuilder> Create();
private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -30,10 +30,6 @@ ...@@ -30,10 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif #endif
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, fout);
}
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
...@@ -473,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, ...@@ -473,9 +464,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
auto &p = places_[0]; result->ops_.emplace_back(
auto *s = local_scopes_[0]; new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars"); ConnectOp(result, result->ops_.back().get(), "send_vars");
......
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -30,27 +32,34 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( ...@@ -30,27 +32,34 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
} }
void NCCLAllReduceOpHandle::RunImpl() { void NCCLAllReduceOpHandle::RunImpl() {
if (inputs_.size() == 1) { if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
// Wait input done // Wait input done
WaitInputVarGenerated(); WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_; auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
int dtype = -1; PADDLE_ENFORCE_EQ(
size_t numel = 0; in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> lod_tensors; std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i]; auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &lod_tensor =
auto &lod_tensor = local_scope.FindVar(var_name)->Get<LoDTensor>(); local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor); lod_tensors.emplace_back(&lod_tensor);
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
"The name of input and output should be equal.");
} }
if (platform::is_gpu_place(lod_tensors[0]->place())) { if (platform::is_gpu_place(lod_tensors[0]->place())) {
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -96,7 +105,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -96,7 +105,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
auto &scope = auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>(); *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i]; auto &p = places_[i];
auto *var = scope.FindVar(var_name); auto *var = scope.FindVar(in_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p]; auto *dev_ctx = dev_ctxes_[p];
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
......
...@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void RunImpl() override; void RunImpl() override;
private: private:
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
const platform::NCCLContextMap &nccl_ctxs_; const platform::NCCLContextMap &nccl_ctxs_;
}; };
......
...@@ -104,6 +104,16 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { ...@@ -104,6 +104,16 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
} }
} }
size_t OpHandleBase::NoDummyInputSize() const {
size_t cnt = 0;
for (auto *in : inputs_) {
if (dynamic_cast<DummyVarHandle *>(in) == nullptr) {
++cnt;
}
}
return cnt;
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) { bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_; return in_var && in_var->generated_op_;
} }
......
...@@ -80,6 +80,8 @@ class OpHandleBase { ...@@ -80,6 +80,8 @@ class OpHandleBase {
const std::vector<VarHandleBase *> &Outputs() const { return outputs_; } const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
size_t NoDummyInputSize() const;
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback); void RunAndRecordEvent(const std::function<void()> &callback);
......
...@@ -32,8 +32,8 @@ namespace framework { ...@@ -32,8 +32,8 @@ namespace framework {
namespace details { namespace details {
struct ReduceOpHandle : public OpHandleBase { struct ReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_; std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> &places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
......
...@@ -19,12 +19,12 @@ namespace framework { ...@@ -19,12 +19,12 @@ namespace framework {
namespace details { namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place, const Scope *local_scope, const std::string &name,
const std::string &name) const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope),
place_(place), name_(name),
name_(name) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle. // TODO(wuyi): need further analysis whether wait VarDummyHandle.
......
...@@ -29,7 +29,7 @@ namespace details { ...@@ -29,7 +29,7 @@ namespace details {
struct RPCOpHandle : public OpHandleBase { struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name); const std::string& name, const platform::Place& place);
std::string Name() const override; std::string Name() const override;
...@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase { ...@@ -43,8 +43,8 @@ struct RPCOpHandle : public OpHandleBase {
private: private:
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_;
const std::string name_; const std::string name_;
platform::Place place_;
}; };
} // namespace details } // namespace details
......
...@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, ...@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
template <typename Callback>
void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}
for (auto &var : graph.dep_vars_) {
callback(*var);
}
}
void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) { for (auto &op : graph->ops_) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
......
...@@ -55,8 +55,6 @@ class SSAGraphBuilder { ...@@ -55,8 +55,6 @@ class SSAGraphBuilder {
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph); static void AddOutputToLeafOps(SSAGraph *graph);
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
namespace paddle {
namespace framework {
namespace details {
template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}
for (auto &var : graph.dep_vars_) {
callback(*var);
}
}
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iosfwd>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const SSAGraph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
SSAGraghBuilderWithPrinter(std::ostream& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ref_(sout) {}
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
std::unique_ptr<SSAGraphPrinter>&& printer,
std::unique_ptr<SSAGraphBuilder>&& builder)
: printer_(std::move(printer)),
builder_(std::move(builder)),
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
}
private:
std::unique_ptr<SSAGraphPrinter> printer_;
std::unique_ptr<SSAGraphBuilder> builder_;
std::unique_ptr<std::ostream> stream_ptr_;
std::ostream& stream_ref_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -87,7 +87,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) { ...@@ -87,7 +87,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
} }
inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) { inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r; bool ret =
(l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r);
#ifdef PADDLE_WITH_MKLDNN
// Layout transform needed for either non-MKLDNN to MKLDNN or vice versa
ret |= (l != DataLayout::kMKLDNN && r == DataLayout::kMKLDNN);
ret |= (l == DataLayout::kMKLDNN && r != DataLayout::kMKLDNN);
#endif
return ret;
} }
inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) { inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
......
...@@ -83,8 +83,14 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> { ...@@ -83,8 +83,14 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE; using T = typename KERNEL_TYPE::ELEMENT_TYPE;
std::string library(library_type);
std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") {
data_layout = "MKLDNNLAYOUT";
}
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
DataLayout::kAnyLayout, StringToLibraryType(library_type)); StringToDataLayout(data_layout),
StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE); OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
...@@ -99,7 +105,8 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> { ...@@ -99,7 +105,8 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type) const {}
}; };
// User can register many kernel in one place. The data type could be different. // User can register many kernel in one place. The data type could be
// different.
template <typename PlaceType, typename... KernelType> template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
...@@ -149,15 +156,15 @@ class OpKernelRegistrar : public Registrar { ...@@ -149,15 +156,15 @@ class OpKernelRegistrar : public Registrar {
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(op_type, LIBRARY_TYPE, place_class, ...) \ #define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##LIBRARY_TYPE##__, \ __reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \ "REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \ static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__(#op_type, \ __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#LIBRARY_TYPE); \ #library_type); \
int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() { \ int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__.Touch(); \ __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
return 0; \ return 0; \
} }
......
...@@ -293,6 +293,38 @@ static Tensor* GetMutableTensorFromVar(Variable* var) { ...@@ -293,6 +293,38 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
} }
} }
bool ExecutionContext::HasInput(const std::string& name) const {
if (!op_.HasInputs(name)) {
return false;
}
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input %s should not have more than one inputs", name);
auto arg = ins[0];
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
return var != nullptr;
}
bool ExecutionContext::HasOutput(const std::string& name) const {
if (!op_.HasOutputs(name)) {
return false;
}
auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto arg = outs[0];
auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
return var != nullptr;
}
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
...@@ -444,9 +476,24 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -444,9 +476,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto* out_tensor = out_var->GetMutable<LoDTensor>(); auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod()); out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators. // TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops. // Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor? // Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
...@@ -646,8 +693,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -646,8 +693,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1, PADDLE_ENFORCE(
"DataType of Paddle Op %s must be the same.", Type()); tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
data_type, tmp);
data_type = tmp; data_type = tmp;
} }
} }
...@@ -665,7 +714,8 @@ OpKernelType OperatorWithKernel::GetExpectedKernelType( ...@@ -665,7 +714,8 @@ OpKernelType OperatorWithKernel::GetExpectedKernelType(
OpKernelType OperatorWithKernel::GetKernelTypeForVar( OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const { const OpKernelType& expected_kernel_type) const {
return OpKernelType(expected_kernel_type.data_type_, tensor.place()); return OpKernelType(expected_kernel_type.data_type_, tensor.place(),
tensor.layout());
} }
} // namespace framework } // namespace framework
......
...@@ -191,9 +191,9 @@ class ExecutionContext { ...@@ -191,9 +191,9 @@ class ExecutionContext {
return op_.Attr<T>(name); return op_.Attr<T>(name);
} }
bool HasInput(const std::string& name) const { return op_.HasInputs(name); } bool HasInput(const std::string& name) const;
bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); } bool HasOutput(const std::string& name) const;
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size(); return op_.Inputs(name).size();
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor( ...@@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = var->Persistable(); var_infos.back().persistable_ = var->Persistable();
} }
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), build_strategy);
#else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_,
build_strategy); build_strategy);
#ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
#endif #endif
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
......
...@@ -34,13 +34,7 @@ DEFINE_bool( ...@@ -34,13 +34,7 @@ DEFINE_bool(
namespace paddle { namespace paddle {
namespace framework { namespace framework {
Scope::~Scope() { Scope::~Scope() { DropKids(); }
DropKids();
for (auto& kv : vars_) {
VLOG(3) << "Destroy variable " << kv.first;
delete kv.second;
}
}
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const { ...@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const {
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
// acquire the lock when new var under this scope
std::unique_lock<std::mutex> lock(mutex_);
auto* v = FindVarLocally(name); auto* v = FindVarLocally(name);
if (v != nullptr) return v; if (v != nullptr) return v;
v = new Variable(); v = new Variable();
vars_[name] = v; vars_[name].reset(v);
VLOG(3) << "Create variable " << name; VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first); v->name_ = &(vars_.find(name)->first);
return v; return v;
...@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) { ...@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) {
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
// acquire the lock when find var
std::unique_lock<std::mutex> lock(mutex_);
return FindVarInternal(name);
}
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name); auto var = FindVarLocally(name);
if (var != nullptr) { if (var != nullptr) {
return var; return var;
} }
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
} }
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
for (auto& kv : vars_) { for (auto& kv : vars_) {
if (kv.second == var) { if (kv.second.get() == var) {
return this; return this;
} }
} }
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
...@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const { ...@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const {
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::unique_lock<std::mutex> lock(mutex_);
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
delete it->second;
it = vars_.erase(it); it = vars_.erase(it);
} else { } else {
++it; ++it;
...@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name, ...@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name,
auto new_it = vars_.find(new_name); auto new_it = vars_.find(new_name);
PADDLE_ENFORCE(new_it == vars_.end(), PADDLE_ENFORCE(new_it == vars_.end(),
"The variable with name %s is already in the scope", new_name); "The variable with name %s is already in the scope", new_name);
vars_[new_name] = origin_it->second; vars_[new_name].reset(origin_it->second.release());
vars_.erase(origin_it); vars_.erase(origin_it);
} }
...@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const { ...@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
Variable* Scope::FindVarLocally(const std::string& name) const { Variable* Scope::FindVarLocally(const std::string& name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) return it->second; if (it != vars_.end()) return it->second.get();
return nullptr; return nullptr;
} }
......
...@@ -47,15 +47,18 @@ class Scope { ...@@ -47,15 +47,18 @@ class Scope {
Scope& NewScope() const; Scope& NewScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable.
Variable* Var(const std::string& name); Variable* Var(const std::string& name);
/// Create a variable with a scope-unique name. /// Create a variable with a scope-unique name.
/// Caller doesn't own the returned Variable.
Variable* Var(std::string* name = nullptr); Variable* Var(std::string* name = nullptr);
void EraseVars(const std::vector<std::string>& var_names); void EraseVars(const std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns /// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find. /// nullptr if cannot find.
/// Caller doesn't own the returned Variable.
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
...@@ -78,13 +81,21 @@ class Scope { ...@@ -78,13 +81,21 @@ class Scope {
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const; std::string Rename(const std::string& origin_name) const;
Variable* FindVarLocally(const std::string& name) const;
private: private:
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {} explicit Scope(Scope const* parent) : parent_(parent) {}
mutable std::unordered_map<std::string, Variable*> vars_; // Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable* FindVarInternal(const std::string& name) const;
// Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const;
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
Scope const* parent_{nullptr}; Scope const* parent_{nullptr};
......
...@@ -15,5 +15,102 @@ limitations under the License. */ ...@@ -15,5 +15,102 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
namespace paddle { namespace paddle {
namespace framework {} namespace framework {
extern size_t SizeOfType(std::type_index type);
void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_LE(
numel() * SizeOfType(type()), memory_size(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}
size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
void* Tensor::mutable_data(platform::Place place, std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
PADDLE_ENFORCE_GE(numel(), 0,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first.");
int64_t size = numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place) ||
platform::is_cuda_pinned_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(
"CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode.");
}
#else
if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type));
} else if (platform::is_cuda_pinned_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPinnedPlace>(
boost::get<platform::CUDAPinnedPlace>(place), size, type));
}
}
#endif
offset_ = 0;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
void* Tensor::mutable_data(platform::Place place) {
PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type());
}
Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size();
*this = src;
return *this;
}
Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
PADDLE_ENFORCE_LT(
begin_idx, end_idx,
"The start row index must be lesser than the end row index.");
if (dims_[0] == 1) {
return *this;
} else {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * SizeOfType(type());
return dst;
}
}
Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
return *this;
}
const DDim& Tensor::dims() const { return dims_; }
int64_t Tensor::numel() const { return product(dims_); }
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -34,6 +34,28 @@ namespace framework { ...@@ -34,6 +34,28 @@ namespace framework {
class LoDTensor; class LoDTensor;
class Tensor { class Tensor {
#ifdef PADDLE_WITH_MKLDNN
public:
inline mkldnn::memory::format format() const { return format_; }
inline void set_format(const mkldnn::memory::format format) {
format_ = format;
}
protected:
/**
* @brief the detail format of memory block which have layout as kMKLDNN
*
* @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
* nChw16c, etc. For a MKLDNN memory block, layout will be set as
* DataLayout::kMKLDNN meanwhile detail memory format will be kept in
* this field.
*/
mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;
#endif
public: public:
template <typename T, size_t D, int MajorType, typename IndexType> template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor; friend struct EigenTensor;
...@@ -54,26 +76,24 @@ class Tensor { ...@@ -54,26 +76,24 @@ class Tensor {
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
inline T* data(); T* data();
/*! Return a pointer to constant memory block. */ /*! Return a pointer to constant memory block. */
template <typename T> template <typename T>
inline const T* data() const; const T* data() const;
inline bool IsInitialized() const; bool IsInitialized() const;
inline void switch_place(platform::Place new_place);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
inline T* mutable_data(platform::Place place); T* mutable_data(platform::Place place);
inline void* mutable_data(platform::Place place, std::type_index type); void* mutable_data(platform::Place place, std::type_index type);
inline void* mutable_data(platform::Place place); void* mutable_data(platform::Place place);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
...@@ -84,19 +104,19 @@ class Tensor { ...@@ -84,19 +104,19 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
inline T* mutable_data(DDim dims, platform::Place place); T* mutable_data(DDim dims, platform::Place place);
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
inline const DDim& dims() const; const DDim& dims() const;
/*! Return the numel of the memory block. */ /*! Return the numel of the memory block. */
inline int64_t numel() const; int64_t numel() const;
/*! Resize the dimensions of the memory block. */ /*! Resize the dimensions of the memory block. */
inline Tensor& Resize(const DDim& dims); Tensor& Resize(const DDim& dims);
/*! The internal of two tensors share the same memory block. */ /*! The internal of two tensors share the same memory block. */
inline Tensor& ShareDataWith(const Tensor& src); Tensor& ShareDataWith(const Tensor& src);
/** /**
* @brief Return a sub-tensor of the given tensor. * @brief Return a sub-tensor of the given tensor.
...@@ -106,7 +126,7 @@ class Tensor { ...@@ -106,7 +126,7 @@ class Tensor {
* @param[in] end_idx The index of the end row(exclusive) to slice. * @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0. * The index number begins from 0.
*/ */
inline Tensor Slice(int begin_idx, int end_idx) const; Tensor Slice(int begin_idx, int end_idx) const;
platform::Place place() const { platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -123,11 +143,11 @@ class Tensor { ...@@ -123,11 +143,11 @@ class Tensor {
// memory size returns the holding memory size in byte. // memory size returns the holding memory size in byte.
size_t memory_size() const; size_t memory_size() const;
inline void check_memory_size() const; void check_memory_size() const;
inline DataLayout layout() const { return layout_; } DataLayout layout() const { return layout_; }
inline void set_layout(const DataLayout layout) { layout_ = layout; } void set_layout(const DataLayout layout) { layout_ = layout; }
private: private:
/** /**
...@@ -197,8 +217,10 @@ class Tensor { ...@@ -197,8 +217,10 @@ class Tensor {
* N,C,H,W for respectively the batch size, the number of * N,C,H,W for respectively the batch size, the number of
* feature maps, the height. * feature maps, the height.
*/ */
// Fix me: here just change the default layout to kNCHW
DataLayout layout_ = DataLayout::kNHWC; // it doesn't fix the real issue, i.e. feeder should set up tensor layout
// according to actual input data
DataLayout layout_ = DataLayout::kNCHW;
/** /**
* @brief A PlaceHolder may be shared by more than one tensor. * @brief A PlaceHolder may be shared by more than one tensor.
...@@ -210,15 +232,6 @@ class Tensor { ...@@ -210,15 +232,6 @@ class Tensor {
size_t offset_; size_t offset_;
}; };
inline void Tensor::switch_place(platform::Place new_place) {
if (holder_->place() == new_place) {
return;
}
// TODO(tonyyang-svail): do memcpy here.
PADDLE_THROW("Not Implemented");
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -20,21 +20,6 @@ limitations under the License. */ ...@@ -20,21 +20,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type);
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_LE(
numel() * SizeOfType(type()), memory_size(),
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}
inline size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
template <typename T> template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
...@@ -73,88 +58,6 @@ inline T* Tensor::mutable_data(platform::Place place) { ...@@ -73,88 +58,6 @@ inline T* Tensor::mutable_data(platform::Place place) {
return reinterpret_cast<T*>(mutable_data(place, typeid(T))); return reinterpret_cast<T*>(mutable_data(place, typeid(T)));
} }
inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
PADDLE_ENFORCE_GE(numel(), 0,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first.");
int64_t size = numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place) ||
platform::is_cuda_pinned_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(
"CUDAPlace or CUDAPinnedPlace is not supported in CPU-only mode.");
}
#else
if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type));
} else if (platform::is_cuda_pinned_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPinnedPlace>(
boost::get<platform::CUDAPinnedPlace>(place), size, type));
}
}
#endif
offset_ = 0;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
inline void* Tensor::mutable_data(platform::Place place) {
PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type());
}
inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size();
*this = src;
return *this;
}
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
PADDLE_ENFORCE_LT(
begin_idx, end_idx,
"The start row index must be lesser than the end row index.");
if (dims_[0] == 1) {
return *this;
} else {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * SizeOfType(type());
return dst;
}
}
inline Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
return *this;
}
inline const DDim& Tensor::dims() const { return dims_; }
inline int64_t Tensor::numel() const { return product(dims_); }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
Tensor res; Tensor res;
res.ShareDataWith(src); res.ShareDataWith(src);
......
...@@ -209,7 +209,7 @@ TEST(Tensor, ReshapeToMatrix) { ...@@ -209,7 +209,7 @@ TEST(Tensor, ReshapeToMatrix) {
TEST(Tensor, Layout) { TEST(Tensor, Layout) {
framework::Tensor src; framework::Tensor src;
ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC); ASSERT_EQ(src.layout(), framework::DataLayout::kNCHW);
src.set_layout(framework::DataLayout::kAnyLayout); src.set_layout(framework::DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout); ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
} }
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -107,6 +109,13 @@ class OrderedRegistry { ...@@ -107,6 +109,13 @@ class OrderedRegistry {
std::vector<std::unique_ptr<T>> data_; std::vector<std::unique_ptr<T>> data_;
}; };
template <typename T>
T &GetFromScope(const framework::Scope &scope, const std::string &name) {
framework::Variable *var = scope.FindVar(name);
PADDLE_ENFORCE(var != nullptr);
return *var->GetMutable<T>();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
# Add TRT tests # Add TRT tests
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine) nv_library(tensorrt_converter
# This test is not stable SRCS mul_op.cc conv2d_op.cc fc_op.cc
# See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828 DEPS tensorrt_engine mul_op)
#nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
# DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine nv_test(test_op_converter SRCS test_op_converter.cc DEPS
# SERIAL) ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL)
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
...@@ -21,7 +22,8 @@ namespace tensorrt { ...@@ -21,7 +22,8 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter { class ReluOpConverter : public OpConverter {
public: public:
ReluOpConverter() {} ReluOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the // Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange. // framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -32,12 +34,17 @@ class ReluOpConverter : public OpConverter { ...@@ -32,12 +34,17 @@ class ReluOpConverter : public OpConverter {
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU); nvinfer1::ActivationType::kRELU);
engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0)); auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
} }
}; };
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
...@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter { ...@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter {
public: public:
Conv2dOpConverter() {} Conv2dOpConverter() {}
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override { const framework::Scope& scope, bool test_mode) override {
LOG(INFO) LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias"; << "convert a fluid conv2d op to tensorrt conv layer without bias";
} }
}; };
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
...@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights, ...@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
class FcOpConverter : public OpConverter { class FcOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override { const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias"; VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter { ...@@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter {
n_output, weight.get(), bias.get()); n_output, weight.get(), bias.get());
auto output_name = op_desc.Output("Out").front(); auto output_name = op_desc.Output("Out").front();
engine_->DeclareOutput(layer, 0, output_name); engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
} }
}; };
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
USE_OP(mul); USE_OP(mul);
...@@ -23,9 +23,8 @@ namespace tensorrt { ...@@ -23,9 +23,8 @@ namespace tensorrt {
*/ */
class MulOpConverter : public OpConverter { class MulOpConverter : public OpConverter {
public: public:
MulOpConverter() {}
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override { const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias"; VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter { ...@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter {
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false, engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
*const_cast<nvinfer1::ITensor*>(input2), false); *const_cast<nvinfer1::ITensor*>(input2), false);
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]); auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
} }
}; };
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(mul);
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
...@@ -34,12 +35,15 @@ class OpConverter { ...@@ -34,12 +35,15 @@ class OpConverter {
// Converter logic for an op. // Converter logic for an op.
virtual void operator()(const framework::proto::OpDesc& op, virtual void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) {} const framework::Scope& scope,
bool test_mode = false) {}
// Convert a single fluid operaotr and add the corresponding layer to TRT. // Convert a single fluid operator and add the corresponding layer to TRT.
// test_mode: whether the instance executes in an unit test.
void ConvertOp(const framework::proto::OpDesc& op, void ConvertOp(const framework::proto::OpDesc& op,
const std::unordered_set<std::string>& parameters, const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) { const framework::Scope& scope, TensorRTEngine* engine,
bool test_mode = false) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
OpConverter* it{nullptr}; OpConverter* it{nullptr};
...@@ -57,7 +61,7 @@ class OpConverter { ...@@ -57,7 +61,7 @@ class OpConverter {
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type()); op_desc.Type());
it->SetEngine(engine); it->SetEngine(engine);
(*it)(op, scope); (*it)(op, scope, test_mode);
} }
// convert fluid block to tensorrt network // convert fluid block to tensorrt network
...@@ -77,6 +81,9 @@ class OpConverter { ...@@ -77,6 +81,9 @@ class OpConverter {
// TensorRT engine // TensorRT engine
TensorRTEngine* engine_{nullptr}; TensorRTEngine* engine_{nullptr};
protected:
bool test_mode_;
private: private:
// registered op converter map, whose key is the fluid op type, and value is // registered op converter map, whose key is the fluid op type, and value is
// the pointer position of corresponding OpConverter class. // the pointer position of corresponding OpConverter class.
...@@ -86,12 +93,23 @@ class OpConverter { ...@@ -86,12 +93,23 @@ class OpConverter {
}; };
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \ #define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \ struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
trt_##op_type__##_converter() { \ trt_##op_type__##_converter() { \
Registry<OpConverter>::Register<Converter__>(#op_type__); \ ::paddle::inference:: \
Registry<paddle::inference::tensorrt::OpConverter>::Register< \
::paddle::inference::tensorrt::Converter__>(#op_type__); \
} \ } \
}; \ }; \
trt_##op_type__##_converter trt_##op_type__##_converter__; trt_##op_type__##_converter trt_##op_type__##_converter__; \
int TouchConverterRegister_##op_type__() { \
trt_##op_type__##_converter__.Touch(); \
return 0; \
}
#define USE_TRT_CONVERTER(op_type__) \
extern int TouchConverterRegister_##op_type__(); \
static int use_op_converter_trt_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_##op_type__();
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
USE_OP(relu);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
void Compare(const std::string op_type, float input, float expect) { TEST(ReluOpConverter, main) {
framework::Scope scope; framework::Scope scope;
platform::CUDAPlace place; std::unordered_set<std::string> parameters;
platform::CUDADeviceContext ctx(place); TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("relu-X", nvinfer1::Dims2(10, 6));
// init fluid op and variable validator.DeclOutputVar("relu-Out", nvinfer1::Dims2(10, 6));
auto x_var = scope.Var("X");
auto x_tensor = x_var->GetMutable<framework::LoDTensor>(); // Prepare Op description
x_tensor->Resize({1, 1}); framework::OpDesc desc;
x_tensor->mutable_data<float>(place); desc.SetType("relu");
std::vector<float> init; desc.SetInput("X", {"relu-X"});
init.push_back(input); desc.SetOutput("Out", {"relu-Out"});
framework::TensorFromVector(init, ctx, x_tensor);
LOG(INFO) << "set OP";
auto out_var = scope.Var("Out"); validator.SetOp(*desc.Proto());
auto out_tensor = out_var->GetMutable<framework::LoDTensor>(); LOG(INFO) << "execute";
out_tensor->Resize({1, 1});
out_tensor->mutable_data<float>(place); validator.Execute(10);
framework::OpDesc op_desc;
op_desc.SetType(op_type);
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
auto op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
op->Run(scope, place);
// get fluid output
std::vector<float> out1;
framework::TensorToVector(*out_tensor, ctx, &out1);
// init tensorrt op
cudaStream_t stream;
ASSERT_EQ(0, cudaStreamCreate(&stream));
TensorRTEngine* engine = new TensorRTEngine(1, 1 << 10, &stream);
engine->InitNetwork();
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
// convert op
OpConverter op_converter;
op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
// convert LoDTensor to ITensor
size_t size = x_tensor->memory_size();
EngineIOConverter::ConvertInput(op_type, *x_tensor,
engine->buffer("X").buffer, size, &stream);
// run tensorrt Outp
engine->Execute(1);
// convert ITensor to LoDTensor
EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer,
out_tensor, size, &stream);
// get tensorrt output
std::vector<float> out2;
framework::TensorToVector(*out_tensor, ctx, &out2);
// compare
ASSERT_EQ(out1[0], out2[0]);
ASSERT_EQ(out1[0], expect);
delete engine;
cudaStreamDestroy(stream);
}
TEST(OpConverter, ConvertRelu) {
Compare("relu", 1, 1); // relu(1) = 1
Compare("relu", -5, 0); // relu(-5) = 0
} }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(activation); USE_OP(relu);
...@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) { ...@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) {
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_TRT_CONVERTER(conv2d)
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -104,8 +105,8 @@ class TRTConvertValidation { ...@@ -104,8 +105,8 @@ class TRTConvertValidation {
void SetOp(const framework::proto::OpDesc& desc) { void SetOp(const framework::proto::OpDesc& desc) {
op_ = framework::OpRegistry::CreateOp(desc); op_ = framework::OpRegistry::CreateOp(desc);
OpConverter op_converter; Singleton<OpConverter>::Global().ConvertOp(
op_converter.ConvertOp(desc, parameters_, scope_, engine_.get()); desc, parameters_, scope_, engine_.get(), true /*test_mode*/);
engine_->FreezeNetwork(); engine_->FreezeNetwork();
...@@ -150,7 +151,8 @@ class TRTConvertValidation { ...@@ -150,7 +151,8 @@ class TRTConvertValidation {
// Compare two output // Compare two output
ASSERT_FALSE(fluid_out.empty()); ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) { for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 1e-6); // Loose the threshold for CI in different machine model.
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5);
} }
} }
} }
......
...@@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) { ...@@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) {
} }
TensorRTEngine::~TensorRTEngine() { TensorRTEngine::~TensorRTEngine() {
cudaStreamSynchronize(*stream_);
// clean buffer // clean buffer
for (auto& buf : buffers_) { for (auto& buf : buffers_) {
if (buf.buffer != nullptr) { if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer)); PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
buf.buffer = nullptr; buf.buffer = nullptr;
buf.max_size = 0; buf.max_size = 0;
...@@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() {
auto& buf = buffer(item.first); auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once. CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second)); PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
VLOG(4) << "buffer malloc " << item.first << " " << item.second << " "
<< buf.buffer;
buf.size = buf.max_size = item.second; buf.size = buf.max_size = item.second;
buf.device = DeviceType::GPU; buf.device = DeviceType::GPU;
} }
...@@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, ...@@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
PADDLE_ENFORCE(input, "infer network add input %s failed", name); PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
analysis::AccuDims(dims.d, dims.nbDims); analysis::AccuDims(dims.d, dims.nbDims);
PADDLE_ENFORCE(input->isNetworkInput());
TensorRTEngine::SetITensor(name, input); TensorRTEngine::SetITensor(name, input);
return input; return input;
} }
...@@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, ...@@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
SetITensor(name, output); SetITensor(name, output);
PADDLE_ENFORCE(output != nullptr); PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str()); output->setName(name.c_str());
PADDLE_ENFORCE(!output->isNetworkInput());
infer_network_->markOutput(*output); infer_network_->markOutput(*output);
PADDLE_ENFORCE(output->isNetworkOutput());
// output buffers' size can only be decided latter, set zero here to mark this // output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter. // and will reset latter.
buffer_sizes_[name] = 0; buffer_sizes_[name] = 0;
...@@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) { ...@@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
auto* output = TensorRTEngine::GetITensor(name); auto* output = TensorRTEngine::GetITensor(name);
PADDLE_ENFORCE(output != nullptr); PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str()); output->setName(name.c_str());
PADDLE_ENFORCE(!output->isNetworkInput());
infer_network_->markOutput(*output); infer_network_->markOutput(*output);
// output buffers' size can only be decided latter, set zero here to mark this // output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter. // and will reset latter.
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase { ...@@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase {
// TensorRT related internal members // TensorRT related internal members
template <typename T> template <typename T>
struct Destroyer { struct Destroyer {
void operator()(T* x) { x->destroy(); } void operator()(T* x) {
if (x) {
x->destroy();
}
}
}; };
template <typename T> template <typename T>
using infer_ptr = std::unique_ptr<T, Destroyer<T>>; using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
...@@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase { ...@@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \ #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS); engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class TRT_EngineManager {
public:
TensorRTEngine* Create(int max_batch, int max_workspace,
cudaStream_t* stream) {
engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
return engines_.back().get();
}
void DeleteALl() {
for (auto& ptr : engines_) {
ptr.reset(nullptr);
}
}
private:
std::vector<std::unique_ptr<TensorRTEngine>> engines_;
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -101,23 +101,22 @@ void SplitData( ...@@ -101,23 +101,22 @@ void SplitData(
} }
void ThreadRunInfer( void ThreadRunInfer(
const int tid, paddle::framework::Executor* executor, const int tid, paddle::framework::Scope* scope,
paddle::framework::Scope* scope,
const std::unique_ptr<paddle::framework::ProgramDesc>& inference_program,
const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) { const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) {
auto copy_program = std::unique_ptr<paddle::framework::ProgramDesc>( // maybe framework:ProgramDesc is not thread-safe
new paddle::framework::ProgramDesc(*inference_program));
auto& sub_scope = scope->NewScope(); auto& sub_scope = scope->NewScope();
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
auto inference_program =
paddle::inference::Load(&executor, scope, FLAGS_model_path);
std::string feed_holder_name = "feed_" + paddle::string::to_string(tid); auto ctx = executor.Prepare(*inference_program, /*block_id*/ 0);
std::string fetch_holder_name = "fetch_" + paddle::string::to_string(tid); executor.CreateVariables(*inference_program, &sub_scope, /*block_id*/ 0);
copy_program->SetFeedHolderName(feed_holder_name);
copy_program->SetFetchHolderName(fetch_holder_name);
const std::vector<std::string>& feed_target_names = const std::vector<std::string>& feed_target_names =
copy_program->GetFeedTargetNames(); inference_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names = const std::vector<std::string>& fetch_target_names =
copy_program->GetFetchTargetNames(); inference_program->GetFetchTargetNames();
PADDLE_ENFORCE_EQ(fetch_target_names.size(), 1UL); PADDLE_ENFORCE_EQ(fetch_target_names.size(), 1UL);
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets; std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
...@@ -131,9 +130,8 @@ void ThreadRunInfer( ...@@ -131,9 +130,8 @@ void ThreadRunInfer(
auto start_ms = GetCurrentMs(); auto start_ms = GetCurrentMs();
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
feed_targets[feed_target_names[0]] = inputs[i]; feed_targets[feed_target_names[0]] = inputs[i];
executor->Run(*copy_program, &sub_scope, &feed_targets, &fetch_targets, executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets,
true /*create_local_scope*/, true /*create_vars*/, &fetch_targets, false /*create_local_scope*/);
feed_holder_name, fetch_holder_name);
} }
auto stop_ms = GetCurrentMs(); auto stop_ms = GetCurrentMs();
scope->DeleteScope(&sub_scope); scope->DeleteScope(&sub_scope);
...@@ -158,22 +156,10 @@ TEST(inference, nlp) { ...@@ -158,22 +156,10 @@ TEST(inference, nlp) {
LOG(INFO) << "Number of samples (seq_len<1024): " << datasets.size(); LOG(INFO) << "Number of samples (seq_len<1024): " << datasets.size();
LOG(INFO) << "Total number of words: " << num_total_words; LOG(INFO) << "Total number of words: " << num_total_words;
const bool model_combined = false;
// 0. Call `paddle::framework::InitDevices()` initialize all the devices // 0. Call `paddle::framework::InitDevices()` initialize all the devices
// 1. Define place, executor, scope
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
std::unique_ptr<paddle::framework::Scope> scope( std::unique_ptr<paddle::framework::Scope> scope(
new paddle::framework::Scope()); new paddle::framework::Scope());
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
inference_program =
InitProgram(&executor, scope.get(), FLAGS_model_path, model_combined);
if (FLAGS_use_mkldnn) {
EnableMKLDNN(inference_program);
}
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// only use 1 thread number per std::thread // only use 1 thread number per std::thread
omp_set_dynamic(0); omp_set_dynamic(0);
...@@ -189,21 +175,30 @@ TEST(inference, nlp) { ...@@ -189,21 +175,30 @@ TEST(inference, nlp) {
start_ms = GetCurrentMs(); start_ms = GetCurrentMs();
for (int i = 0; i < FLAGS_num_threads; ++i) { for (int i = 0; i < FLAGS_num_threads; ++i) {
threads.emplace_back( threads.emplace_back(
new std::thread(ThreadRunInfer, i, &executor, scope.get(), new std::thread(ThreadRunInfer, i, scope.get(), std::ref(jobs)));
std::ref(inference_program), std::ref(jobs)));
} }
for (int i = 0; i < FLAGS_num_threads; ++i) { for (int i = 0; i < FLAGS_num_threads; ++i) {
threads[i]->join(); threads[i]->join();
} }
stop_ms = GetCurrentMs(); stop_ms = GetCurrentMs();
} else { } else {
if (FLAGS_prepare_vars) { // 1. Define place, executor, scope
executor.CreateVariables(*inference_program, scope.get(), 0); auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path,
/*model combined*/ false);
if (FLAGS_use_mkldnn) {
EnableMKLDNN(inference_program);
} }
// always prepare context // always prepare context
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx; std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
ctx = executor.Prepare(*inference_program, 0); ctx = executor.Prepare(*inference_program, 0);
if (FLAGS_prepare_vars) {
executor.CreateVariables(*inference_program, scope.get(), 0);
}
// preapre fetch // preapre fetch
const std::vector<std::string>& fetch_target_names = const std::vector<std::string>& fetch_target_names =
inference_program->GetFetchTargetNames(); inference_program->GetFetchTargetNames();
......
...@@ -166,8 +166,6 @@ function(op_library TARGET) ...@@ -166,8 +166,6 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually. # NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation") if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
elseif(${TARGET} STREQUAL "fake_dequantize") elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
else() else()
...@@ -227,6 +225,8 @@ op_library(softmax_op DEPS softmax) ...@@ -227,6 +225,8 @@ op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND) if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine) op_library(tensorrt_engine_op DEPS tensorrt_engine)
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
......
...@@ -24,12 +24,12 @@ namespace operators { ...@@ -24,12 +24,12 @@ namespace operators {
: public ::paddle::framework::OpProtoAndCheckerMaker { \ : public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \ public: \
void Make() override { \ void Make() override { \
AddInput("X", "Input of " #OP_NAME "operator"); \ AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \ AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \ AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \ "(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \ .SetDefault(false); \
AddComment(#OP_COMMENT); \ AddComment(OP_COMMENT); \
} \ } \
} }
...@@ -58,14 +58,16 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -58,14 +58,16 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper, const framework::OperatorWithKernel& oper,
const std::string& name) { const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn"); auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()), framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
ctx.GetPlace(), layout, library); ctx.GetPlace(), layout, library);
......
...@@ -111,14 +111,16 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -111,14 +111,16 @@ class BatchNormOp : public framework::OperatorWithKernel {
"Variance input should be of float type"); "Variance input should be of float type");
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_); library_);
} }
...@@ -367,17 +369,18 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -367,17 +369,18 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
} }
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library_); layout_, library_);
} }
}; };
......
...@@ -75,6 +75,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -75,6 +75,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
...@@ -84,6 +89,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -84,6 +89,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
...@@ -99,9 +105,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -99,9 +105,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"float16 can only be used when CUDNN is used"); "float16 can only be used when CUDNN is used");
} }
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library); library);
} }
...@@ -309,6 +312,10 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -309,6 +312,10 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpGrad::GetExpectedKernelType( framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -318,12 +325,10 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -318,12 +325,10 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_); layout_, library_);
......
...@@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,13 @@ class CropOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", y_dim); ctx->SetOutputDim("Out", y_dim);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
}; };
class CropOpMaker : public framework::OpProtoAndCheckerMaker { class CropOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -60,13 +67,19 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -60,13 +67,19 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
"The input used as reference for cropping, " "The input used as reference for cropping, "
"which is of the same dimensions as X.") "which is of the same dimensions as X.")
.AsDispensable(); .AsDispensable();
AddInput("Offsets",
"The input used to describe offsets in runtime, which is a "
"1-D vector whose size equals to the rank of input 'X'. The "
"elements data type must be int.")
.AsDispensable();
AddOutput("Out", AddOutput("Out",
"The output of crop op, " "The output of crop op, "
"which is of the same dimensions as X."); "which is of the same dimensions as X.");
AddAttr<std::vector<int>>("offsets", AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped. " "A list<int> describing offsets to be cropped. "
"The size of offsets list should be the same as " "The size of offsets list should be the same as "
"the dimension size of input X."); "the dimension size of input X.")
.SetDefault(std::vector<int>());
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output. " "A list<int> describing the shape of output. "
"The size of shape list should be the same as " "The size of shape list should be the same as "
...@@ -77,6 +90,17 @@ Crop Operator. ...@@ -77,6 +90,17 @@ Crop Operator.
Crop input into output, as specified by offsets and shape. Crop input into output, as specified by offsets and shape.
There are two ways to set the offsets:
1. In runtime: Using the input 'Offsets', which is a Vairbale and can be
output of other operators. This way is suitable for
dynamic offsets.
2. In network configuration: Using the attribute 'offsets', which will be
set in Python configure script. This way is
suitable for fixed offsets.
You CANNOT use these two ways at the same time. An exception will be raised
if input 'Offset' is configured and meanwhile the attribute 'offsets' is
not empty.
There are two ways to set shape: There are two ways to set shape:
1. reference input: crop input X into the same shape as reference input. 1. reference input: crop input X into the same shape as reference input.
The dimension of reference input should The dimension of reference input should
...@@ -146,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -146,6 +170,15 @@ class CropOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -27,6 +27,37 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor, ...@@ -27,6 +27,37 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor; using framework::Tensor;
static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
std::vector<int> res;
int rank = ctx.Input<Tensor>("X")->dims().size();
if (ctx.HasInput("Offsets")) {
PADDLE_ENFORCE(ctx.Attr<std::vector<int>>("offsets").empty(),
"Input 'Offsets' and attribute 'offsets' should not be used "
"at the same time.");
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1);
PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data;
framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) {
offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
}
res = std::vector<int>(offsets_data, offsets_data + rank);
} else {
res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
rank, res.size(),
"Offsets size should be equal to dimension size of input tensor.");
}
return res;
}
template <typename T> template <typename T>
class CropKernel : public framework::OpKernel<T> { class CropKernel : public framework::OpKernel<T> {
public: public:
...@@ -37,10 +68,7 @@ class CropKernel : public framework::OpKernel<T> { ...@@ -37,10 +68,7 @@ class CropKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_stride = framework::stride(x->dims()); auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims()); auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets"); auto offsets = GetOffsets(context);
PADDLE_ENFORCE_EQ(
x->dims().size(), static_cast<int64_t>(offsets.size()),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0; int64_t offset = 0;
for (size_t i = 0; i < offsets.size(); ++i) { for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]); offset += (x_stride[i] * offsets[i]);
...@@ -56,7 +84,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { ...@@ -56,7 +84,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
if (d_x != nullptr) { if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace()); d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets"); auto offsets = GetOffsets(context);
Eigen::array<std::pair<int, int>, D> paddings; Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < D; ++i) { for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i]; paddings[i].first = offsets[i];
......
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory) selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -25,29 +25,15 @@ namespace paddle { ...@@ -25,29 +25,15 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
std::once_flag RPCClient::init_flag_; void GRPCClient::InitImpl() { InitEventLoop(); }
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr); void GRPCClient::InitEventLoop() {
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
rpc_client_->InitEventLoop();
}
void RPCClient::InitEventLoop() {
// start the client process thread // start the client process thread
// TODO(wuyi): can make this in a threadpool // TODO(wuyi): can make this in a threadpool
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this))); client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
} }
RPCClient::~RPCClient() { GRPCClient::~GRPCClient() {
Wait(); Wait();
cq_.Shutdown(); cq_.Shutdown();
{ {
...@@ -59,11 +45,10 @@ RPCClient::~RPCClient() { ...@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
client_thread_->join(); client_thread_->join();
} }
bool RPCClient::AsyncSendVariable(const std::string& ep, bool GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name, int64_t time_out) {
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
...@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { ...@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp); result->Swap(&tmp);
} }
bool RPCClient::AsyncGetVariable(const std::string& ep, bool GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name, int64_t time_out) {
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
...@@ -155,7 +139,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -155,7 +139,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true; return true;
} }
bool RPCClient::AsyncPrefetchVariable(const std::string& ep, bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
...@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
return true; return true;
} }
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
...@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
req_count_++; req_count_++;
} }
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
...@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_++; req_count_++;
} }
void RPCClient::Wait() { void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; }); sync_cond_.wait(lk, [this] { return req_count_ == 0; });
} }
void RPCClient::Proceed() { void GRPCClient::Proceed() {
void* tag = nullptr; void* tag = nullptr;
bool ok = false; bool ok = false;
...@@ -251,7 +237,7 @@ void RPCClient::Proceed() { ...@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
} }
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe // TODO(Yancey1989): make grpc client completely thread-safe
std::lock_guard<std::mutex> guard(chan_mutex_); std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep); auto it = channels_.find(ep);
......
...@@ -38,6 +38,7 @@ limitations under the License. */ ...@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
class RPCClient { class GRPCClient : public RPCClient {
public: public:
RPCClient() {} GRPCClient() {}
~RPCClient(); virtual ~GRPCClient();
static RPCClient* GetInstance(); bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncSendVariable(const std::string& ep, bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name,
const framework::Scope& scope, int64_t time_out = RPCClient::rpc_time_out) override;
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncPrefetchVariable(const std::string& ep, bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = 600 * 1000); int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(const std::string& ep, void AsyncSendBatchBarrier(
int64_t time_out = 600 * 1000); const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(const std::string& ep, void AsyncSendFetchBarrier(
int64_t time_out = 600 * 1000); const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait(); void Wait() override;
protected:
void InitImpl() override;
private:
// InitEventLoop should only be called by Init() // InitEventLoop should only be called by Init()
void InitEventLoop(); void InitEventLoop();
private:
void Proceed(); void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
...@@ -218,9 +218,7 @@ class RPCClient { ...@@ -218,9 +218,7 @@ class RPCClient {
// mutex for GetChannel thread safety // mutex for GetChannel thread safety
std::mutex chan_mutex_; std::mutex chan_mutex_;
static std::unique_ptr<RPCClient> rpc_client_; DISABLE_COPY_AND_ASSIGN(GRPCClient);
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
}; };
} // namespace detail } // namespace detail
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) { ...@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
detail::RPCClient* client = detail::RPCClient::GetInstance(); detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
...@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) { ...@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
std::string in_var_name("ids"); std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait(); client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
......
...@@ -80,7 +80,6 @@ class RequestHandler { ...@@ -80,7 +80,6 @@ class RequestHandler {
} }
framework::ProgramDesc* program() { return program_; } framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; } framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request. // This function processes user's rpc request.
// The implemention is in request_handler_impl. // The implemention is in request_handler_impl.
...@@ -113,13 +112,7 @@ class RequestHandler { ...@@ -113,13 +112,7 @@ class RequestHandler {
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>* std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_; grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_; RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
}; };
} // namespace detail } // namespace detail
......
...@@ -63,16 +63,22 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -63,16 +63,22 @@ bool RequestSendHandler::Handle(const std::string& varname,
PADDLE_THROW("sync: Can not find server side var"); PADDLE_THROW("sync: Can not find server side var");
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) { if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_); std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar); sparse_vars_.push_back(invar);
} }
} }
return true; return true;
} }
void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
......
...@@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler { ...@@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar) override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
}; };
class RequestGetHandler final : public RequestHandler { class RequestGetHandler final : public RequestHandler {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h"
namespace paddle {
namespace operators {
namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace operators {
namespace detail {
class RPCClient {
public:
virtual bool AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000;
template <typename T>
static RPCClient* GetInstance() {
std::call_once(init_flag_, &RPCClient::Init<T>);
return rpc_client_.get();
}
// Init is called by GetInstance.
template <typename T>
static void Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new T());
rpc_client_->InitImpl();
}
}
protected:
virtual void InitImpl() {}
private:
static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
...@@ -60,6 +60,7 @@ class RPCServer { ...@@ -60,6 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name); void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter(); void ResetBarrierCounter();
protected: protected:
......
...@@ -43,7 +43,7 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -43,7 +43,7 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FCOp::GetExpectedKernelType( framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library{framework::LibraryType::kMKLDNN};
framework::DataLayout layout{framework::DataLayout::kAnyLayout}; framework::DataLayout layout{framework::DataLayout::kMKLDNN};
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
...@@ -65,7 +65,7 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -65,7 +65,7 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FCOpGrad::GetExpectedKernelType( framework::OpKernelType FCOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library{framework::LibraryType::kMKLDNN};
framework::DataLayout layout{framework::DataLayout::kAnyLayout}; framework::DataLayout layout{framework::DataLayout::kMKLDNN};
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance(); detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
rpc_client->Wait(); rpc_client->Wait();
......
...@@ -43,7 +43,8 @@ TEST(Gather, GatherData) { ...@@ -43,7 +43,8 @@ TEST(Gather, GatherData) {
auto* cpu_place = new paddle::platform::CPUPlace(); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place); paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::CPUGather<int>(ctx, *src, *index, output); paddle::operators::CPUGather<int>(ctx, *src, *index, output);
delete cpu_place;
cpu_place = NULL;
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
......
...@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client; detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
} }
client.Wait(); client->Wait();
VLOG(3) << "sending completed..."; VLOG(3) << "sending completed...";
} }
......
...@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
...@@ -146,18 +143,12 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -146,18 +143,12 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
} // while(true) } // while(true)
} }
......
...@@ -124,16 +124,17 @@ namespace { ...@@ -124,16 +124,17 @@ namespace {
framework::OpKernelType GetExpectedLRNKernel( framework::OpKernelType GetExpectedLRNKernel(
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_); layout_, library_);
......
...@@ -77,6 +77,8 @@ TEST(math_function, gemm_trans_clbas) { ...@@ -77,6 +77,8 @@ TEST(math_function, gemm_trans_clbas) {
paddle::platform::CPUDeviceContext context(*cpu_place); paddle::platform::CPUDeviceContext context(*cpu_place);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3, GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
input2_ptr + 3, 3, 1, input3_ptr + 1, 4); input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
delete cpu_place;
cpu_place = NULL;
EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24); EXPECT_EQ(input3_ptr[1], 24);
......
...@@ -24,10 +24,13 @@ using mkldnn::pooling_backward; ...@@ -24,10 +24,13 @@ using mkldnn::pooling_backward;
// Generate keys for storing/retriving primitives for this operator // Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial // TODO(jczaja): Make hashing function more optimial
static std::string gethash(memory::dims& input_dims, std::string& pooling_type, static std::string gethash(const memory::dims& input_dims,
std::vector<int>& ksize, std::vector<int>& strides, const std::string& pooling_type,
std::vector<int>& paddings, std::string suffix) { const std::vector<int>& ksize,
auto dims2str = [](memory::dims& operand_dims) { const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& suffix) {
auto dims2str = [](const memory::dims& operand_dims) {
std::string dstr = ""; std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) { for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-"; dstr += std::to_string(operand_dims[i]) + "-";
......
...@@ -83,6 +83,9 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -83,6 +83,9 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -92,11 +95,10 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( ...@@ -92,11 +95,10 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_); layout_, library_);
...@@ -112,6 +114,9 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { ...@@ -112,6 +114,9 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
...@@ -121,6 +126,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -121,6 +126,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
...@@ -129,8 +135,6 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -129,8 +135,6 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used"); "float16 can only be used when CUDNN is used");
} }
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_); library_);
} }
......
...@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto rpc_client = detail::RPCClient::GetInstance(); detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back"; << outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i], rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]);
outs[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
......
...@@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel { ...@@ -20,7 +20,6 @@ class RandomCropOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
...@@ -36,7 +35,7 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -36,7 +35,7 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Seed", "The random seed."); AddInput("Seed", "The random seed.");
AddOutput("Out", "The cropped instance batch."); AddOutput("Out", "The cropped instance batch.");
AddOutput("SeedOut", "The random seed after random cropping.") AddOutput("SeedOut", "The random seed after random cropping.")
.AsDispensable(); .AsIntermediate();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance."); AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddComment(R"DOC( AddComment(R"DOC(
This operator takes a batch of instance, and do random cropping on each instance. This operator takes a batch of instance, and do random cropping on each instance.
......
...@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance(); detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
} }
if (sync_mode) { if (sync_mode) {
rpc_client->Wait(); rpc_client->Wait();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_max);
REGISTER_OP_CPU_KERNEL(
reduce_max, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MaxFunctor>);
REGISTER_OP_CPU_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_min_max_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_max,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MaxFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MaxOrMinGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MaxOrMinGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_mean_op.h"
REGISTER_REDUCE_OP(reduce_mean);
REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_mean_op.h"
REGISTER_OP_CUDA_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
double, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CUDADeviceContext,
int64_t, ops::MeanFunctor>);
REGISTER_OP_CUDA_KERNEL(
reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MeanGradFunctor>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/reduce_op.h"
namespace paddle {
namespace operators {
struct MeanFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->mean(dim);
}
};
struct MeanGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
const Dim& dim, int size) {
dx->device(place) = dy->broadcast(dim) / dx->constant(size);
}
};
} // namespace operators
} // namespace paddle
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册