提交 54002c3b 编写于 作者: Y yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/polish_doc

...@@ -55,12 +55,13 @@ option(WITH_FLUID_ONLY "Compile PaddlePaddle fluid only" OFF) ...@@ -55,12 +55,13 @@ option(WITH_FLUID_ONLY "Compile PaddlePaddle fluid only" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON) option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF) option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF) option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF) option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -147,7 +148,16 @@ include(external/any) # download libn::any ...@@ -147,7 +148,16 @@ include(external/any) # download libn::any
include(external/eigen) # download eigen3 include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11 include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/grpc)
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
include(external/grpc)
else()
include(external/leveldb)
include(external/brpc)
endif()
endif()
include(external/snappy) # download snappy include(external/snappy) # download snappy
include(external/snappystream) include(external/snappystream)
include(external/threadpool) include(external/threadpool)
......
...@@ -28,6 +28,8 @@ Currently supported `--model` argument include: ...@@ -28,6 +28,8 @@ Currently supported `--model` argument include:
``` ```
You can choose to use GPU/CPU training. With GPU training, you can specify You can choose to use GPU/CPU training. With GPU training, you can specify
`--gpus <gpu_num>` to run multi GPU training. `--gpus <gpu_num>` to run multi GPU training.
You can set async mode parameter server. With async mode, you can specify
`--async_mode` to train model asynchronous.
* Run distributed training with parameter servers: * Run distributed training with parameter servers:
* see [run_fluid_benchmark.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/fluid/run_fluid_benchmark.sh) as an example. * see [run_fluid_benchmark.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/fluid/run_fluid_benchmark.sh) as an example.
* start parameter servers: * start parameter servers:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
__all__ = ['parse_args', ]
BENCHMARK_MODELS = [
"machine_translation", "resnet", "vgg", "mnist", "stacked_dynamic_lstm"
]
def parse_args():
parser = argparse.ArgumentParser('Fluid model benchmarks.')
parser.add_argument(
'--model',
type=str,
choices=BENCHMARK_MODELS,
default='resnet',
help='The model to run benchmark with.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
# args related to learning rate
parser.add_argument(
'--learning_rate', type=float, default=0.001, help='The learning rate.')
# TODO(wuyi): add "--use_fake_data" option back.
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--gpus',
type=int,
default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
# this option is available only for vgg and resnet.
parser.add_argument(
'--cpus',
type=int,
default=1,
help='If cpus > 1, will use ParallelDo to run, else use Executor.')
parser.add_argument(
'--data_set',
type=str,
default='flowers',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--no_test',
action='store_true',
help='If set, do not test the testset during training.')
parser.add_argument(
'--memory_optimize',
action='store_true',
help='If set, optimize runtime memory before start.')
parser.add_argument(
'--use_fake_data',
action='store_true',
help='If set ommit the actual read data operators.')
parser.add_argument(
'--profile', action='store_true', help='If set, profile a few steps.')
parser.add_argument(
'--update_method',
type=str,
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--use_reader_op',
action='store_true',
help='Whether to use reader op, and must specify the data path if set this to true.'
)
parser.add_argument(
'--data_path',
type=str,
default="",
help='Directory that contains all the training recordio files.')
args = parser.parse_args()
return args
...@@ -24,108 +24,7 @@ import paddle.fluid.core as core ...@@ -24,108 +24,7 @@ import paddle.fluid.core as core
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler
BENCHMARK_MODELS = [ from args import *
"machine_translation", "resnet", "vgg", "mnist", "stacked_dynamic_lstm"
]
def parse_args():
parser = argparse.ArgumentParser('Fluid model benchmarks.')
parser.add_argument(
'--model',
type=str,
choices=BENCHMARK_MODELS,
default='resnet',
help='The model to run benchmark with.')
parser.add_argument(
'--batch_size',
type=int,
default=32,
help='The batch size on each gpu.')
parser.add_argument(
'--learning_rate', type=float, default=0.001, help='The learning rate.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations',
type=int,
default=80,
help='The number of minibatches, set to -1 to run all batches.')
parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--gpus',
type=int,
default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
# this option is available only for vgg and resnet.
parser.add_argument(
'--cpus',
type=int,
default=1,
help='If cpus > 1, will use ParallelDo to run, else use Executor.')
parser.add_argument(
'--data_set',
type=str,
default='flowers',
choices=['cifar10', 'flowers', 'imagenet'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--no_test',
action='store_true',
help='If set, do not test the testset during training.')
parser.add_argument(
'--memory_optimize',
action='store_true',
help='If set, optimize runtime memory before start.')
parser.add_argument(
'--use_fake_data',
action='store_true',
help='If set ommit the actual read data operators.')
parser.add_argument(
'--profile', action='store_true', help='If set, profile a few steps.')
parser.add_argument(
'--update_method',
type=str,
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--use_reader_op',
action='store_true',
help='Whether to use reader op, and must specify the data path if set this to true.'
)
parser.add_argument(
'--data_path',
type=str,
default="",
help='Directory that contains all the training recordio files.')
args = parser.parse_args()
return args
def append_nccl2_prepare(trainer_id): def append_nccl2_prepare(trainer_id):
...@@ -160,7 +59,7 @@ def append_nccl2_prepare(trainer_id): ...@@ -160,7 +59,7 @@ def append_nccl2_prepare(trainer_id):
"nccl-based dist train.") "nccl-based dist train.")
def dist_transpile(trainer_id): def dist_transpile(trainer_id, args):
if trainer_id < 0: if trainer_id < 0:
return None, None return None, None
...@@ -182,7 +81,12 @@ def dist_transpile(trainer_id): ...@@ -182,7 +81,12 @@ def dist_transpile(trainer_id):
training_role = os.getenv("PADDLE_TRAINING_ROLE") training_role = os.getenv("PADDLE_TRAINING_ROLE")
t = distribute_transpiler.DistributeTranspiler() t = distribute_transpiler.DistributeTranspiler()
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers) t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=not args.async_mode,
slice_var_up=not args.no_split_var)
if training_role == "PSERVER": if training_role == "PSERVER":
pserver_program = t.get_pserver_program(current_endpoint) pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup_program = t.get_startup_program(current_endpoint, pserver_startup_program = t.get_startup_program(current_endpoint,
...@@ -276,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -276,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))), print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation # evaluation
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
pass_test_acc = test(exe, infer_prog, test_reader, feeder, pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print(", Test Accuracy: %f" % pass_test_acc) print(", Test Accuracy: %f" % pass_test_acc)
...@@ -373,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -373,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_id += 1 batch_id += 1
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
# we have not implement record io for test
# skip test when use args.use_reader_op
test_acc = test(startup_exe, infer_prog, test_reader, feeder, test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc)) print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
exit(0)
def print_arguments(args): def print_arguments(args):
...@@ -417,7 +322,7 @@ def main(): ...@@ -417,7 +322,7 @@ def main():
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
if args.update_method == "pserver": if args.update_method == "pserver":
train_prog, startup_prog = dist_transpile(trainer_id) train_prog, startup_prog = dist_transpile(trainer_id, args)
if not train_prog: if not train_prog:
raise Exception( raise Exception(
"Must configure correct environments to run dist train.") "Must configure correct environments to run dist train.")
......
...@@ -199,7 +199,10 @@ def get_model(args): ...@@ -199,7 +199,10 @@ def get_model(args):
batched_train_reader = paddle.batch( batched_train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader, buf_size=5120), train_reader, buf_size=5120),
batch_size=args.batch_size * args.gpus) batch_size=args.batch_size * args.gpus,
batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size) drop_last=True)
batched_test_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc return avg_cost, inference_program, optimizer, batched_train_reader,\
batched_test_reader, batch_acc
...@@ -104,8 +104,9 @@ def get_model(args): ...@@ -104,8 +104,9 @@ def get_model(args):
loss = fluid.layers.mean(x=loss) loss = fluid.layers.mean(x=loss)
# add acc # add acc
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=logit, label=fluid.layers.data(name='label', \ batch_acc = fluid.layers.accuracy(input=logit, label=fluid.layers.data(name='label', \
shape=[1], dtype='int64')) shape=[1], dtype='int64'), total=batch_size_tensor)
inference_program = fluid.default_main_program().clone() inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program): with fluid.program_guard(inference_program):
......
...@@ -82,7 +82,8 @@ def get_model(args): ...@@ -82,7 +82,8 @@ def get_model(args):
data_file, batch_size=args.batch_size)) data_file, batch_size=args.batch_size))
images, label = fluid.layers.read_file(data_file) images, label = fluid.layers.read_file(data_file)
else: else:
images = fluid.layers.data(name='data', shape=dshape, dtype='float32') images = fluid.layers.data(
name='data', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program # Train program
......
...@@ -118,6 +118,10 @@ endif() ...@@ -118,6 +118,10 @@ endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()
if(WITH_GOLANG) if(WITH_GOLANG)
# we need to symlink Paddle directory into GOPATH. If we # we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go # don't do it and we have code that depends on Paddle, go
...@@ -166,3 +170,7 @@ if(WITH_GOLANG) ...@@ -166,3 +170,7 @@ if(WITH_GOLANG)
endif() endif()
endif(WITH_GOLANG) endif(WITH_GOLANG)
if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
SET(BRPC_INCLUDE_DIR "${BRPC_INSTALL_DIR}/include" CACHE PATH "brpc include directory." FORCE)
SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc library." FORCE)
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/brpc/brpc"
GIT_TAG "6d153dd7ff00f960ae6895c9c5fff0ce9f07aff2"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
${EXTERNAL_OPTIONAL_ARGS}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ADD_DEPENDENCIES(extern_brpc protobuf leveldb gflags glog gtest snappy)
ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
LIST(APPEND external_project_dependencies brpc)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(LEVELDB_SOURCES_DIR ${THIRD_PARTY_PATH}/leveldb)
SET(LEVELDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/leveldb)
SET(LEVELDB_INCLUDE_DIR "${LEVELDB_INSTALL_DIR}/include" CACHE PATH "leveldb include directory." FORCE)
SET(LEVELDB_LIBRARIES "${LEVELDB_INSTALL_DIR}/lib/libleveldb.a" CACHE FILEPATH "leveldb library." FORCE)
INCLUDE_DIRECTORIES(${LEVELDB_INCLUDE_DIR})
ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
&& cp ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/libleveldb.a ${LEVELDB_LIBRARIES}
&& cp -r ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/include ${LEVELDB_INSTALL_DIR}/
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES(extern_leveldb snappy)
ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES})
ADD_DEPENDENCIES(leveldb extern_leveldb)
LIST(APPEND external_project_dependencies leveldb)
...@@ -610,3 +610,21 @@ function(grpc_library TARGET_NAME) ...@@ -610,3 +610,21 @@ function(grpc_library TARGET_NAME)
COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}") cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}")
endfunction() endfunction()
function(brpc_library TARGET_NAME)
set(oneValueArgs PROTO)
set(multiValueArgs SRCS DEPS)
set(options "")
cmake_parse_arguments(brpc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
message(STATUS "generating brpc ${brpc_library_PROTO}")
get_filename_component(ABS_PROTO ${brpc_library_PROTO} ABSOLUTE)
get_filename_component(PROTO_WE ${brpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
cc_library("${TARGET_NAME}_proto" SRCS "${brpc_proto_srcs}")
cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
endfunction()
# Automatic Differentiation with the Tape
## Automatic Differentiation
A key challenge in the field of deep learning is to automatically derive the backward pass from the forward pass described algorithmically by researchers. Such a derivation, or a transformation of the forward pass program, has been long studied before the recent prosperity of deep learning in the field known as [automatic differentiation](https://arxiv.org/pdf/1502.05767.pdf).
## The Tape
Given the forward pass program (usually in Python in practices), there are two strategies to derive the backward pass:
1. from the forward pass program itself, or
1. from the execution trace of the forward pass program, which is often known as the *tape*.
This article surveys systems that follow the latter strategy.
## Dynamic Network
When we train a deep learning model, the tape changes every iteration as the input data change, so we have to re-derive the backward pass every iteration. This is known as *dynamic network*.
Deep learning systems that utilize the idea of dynamic network gained their popularities in recent years. This article surveys two representative systems: [PyTorch](https://pytorch.org/) and [DyNet](https://dynet.readthedocs.io/en/latest/).
## An Overview
Both frameworks record a ‘tape’ of the computation and interpreting (or run-time compiling) a transformation of the tape played back in reverse. This tape is a different kind of entity than the original program.[[link]](http://www.bcl.hamilton.ie/~barak/papers/toplas-reverse.pdf)
Consider the following code feedforward model.
```python
x = Variable(randn(20, 1)))
label = Variable(randint(1))
W_1, W_2 = Variable(randn(20, 20)), Variable(randn(10, 20))
h = matmul(W_1, x)
pred = matmul(W_2, x)
loss = softmax(pred, label)
loss.backward()
```
### 1) Dynet uses List to encode the Tape
During the forward execution, a list of operators, in this case `matmul`, `matmul` and `softmax`, are recorded in the tape, along with the necessary information needed to do the backward such as pointers to the inputs and outputs. Then the tape is played in reverse order at `loss.backward()`.
<details>
<summary></summary>
digraph g {
graph [
rankdir = "LR"
];
node [
fontsize = "16"
shape = "ellipse"
];
edge [];
"node0" [
label = "<f0> type: matmul | <f1> input: W_1, x | <f2> output: h"
shape = "record"
];
"node1" [
label = "<f0> type: matmul | <f1> input: W_2, h | <f2> output: pred"
shape = "record"
];
"node2" [
label = "<f0> type: softmax | <f1> input: pred, label | <f2> output: loss"
shape = "record"
];
"node0":f0 -> "node1":f0 [];
"node1":f0 -> "node2":f0 [];
}
</details>
![Alt text](https://g.gravizo.com/svg?digraph%20g%20{%20graph%20[%20rankdir%20=%20%22LR%22%20];%20node%20[%20fontsize%20=%20%2216%22%20shape%20=%20%22ellipse%22%20];%20edge%20[];%20%22node0%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20matmul%20|%20%3Cf1%3E%20input:%20W_1,%20x%20|%20%3Cf2%3E%20output:%20h%22%20shape%20=%20%22record%22%20];%20%22node1%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20matmul%20|%20%3Cf1%3E%20input:%20W_2,%20h%20|%20%3Cf2%3E%20output:%20pred%22%20shape%20=%20%22record%22%20];%20%22node2%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20softmax%20|%20%3Cf1%3E%20input:%20pred,%20label%20|%20%3Cf2%3E%20output:%20loss%22%20shape%20=%20%22record%22%20];%20%22node0%22:f0%20-%3E%20%22node1%22:f0%20[%20id%20=%200%20];%20%22node1%22:f0%20-%3E%20%22node2%22:f0%20[%20id%20=%201%20];%20})
### 2) Pytorch uses Node Graph to encode the Tape
The graph is composed of `Variable`s and `Function`s. During the forward execution, a `Variable` records its creator function, e.g. `h.creator = matmul`. And a Function records its inputs' previous/dependent functions `prev_func` through `creator`, e.g. `matmul.prev_func = matmul1`. At `loss.backward()`, a topological sort is performed on all `prev_func`s. Then the grad op is performed by the sorted order.
<details>
<summary></summary>
digraph g {
graph [
rankdir = "LR"
];
subgraph function {
node [
fontsize = "16"
style = filled
shape = "record"
];
"matmul0" [ label = "<f0> type: matmul | prev_func: None" ];
"matmul1" [ label = "<f0> type: matmul | prev_func: matmul" ];
"softmax" [ label = "<f0> type: softmax | prev_func: matmul" ];
}
subgraph variable {
node [
fontsize = "16"
shape = "Mrecord"
style = filled
fillcolor = white
];
"x" [ label = "<f0> x | <f1> creator: None" ];
"label" [ label = "<f0> label | <f1> creator: None" ];
"W_1" [ label = "<f0> W_1 | <f1> creator: None" ];
"W_2" [ label = "<f0> W_2 | <f1> creator: None" ];
"h" [ label = "<f0> h | <f1> creator: None" ];
"pred" [ label = "<f0> pred | <f1> creator: matmul" ];
"loss" [ label = "<f0> loss | <f1> creator: softmax" ];
}
subgraph data_flow {
"x":f0 -> "matmul0":f0;
"W_1":f0 -> "matmul0":f0;
"matmul0":f0 -> "h":f0;
"h":f0 -> "matmul1":f0;
"W_2":f0 -> "matmul1":f0;
"matmul1":f0 -> "pred":f0;
"pred":f0 -> "softmax":f0;
"label":f0 -> "softmax":f0;
"softmax":f0 -> "loss":f0;
}
subgraph prev_func {
edge [color="red", arrowsize="0.6", penwidth="1", constraint=false];
"matmul1":f1 -> "matmul0":f0;
"softmax":f1 -> "matmul1":f0;
label = "prev_func";
}
}
</details>
![Alt text](https://g.gravizo.com/svg?digraph%20g%20{%20graph%20[%20rankdir%20=%20%22LR%22%20];%20subgraph%20function%20{%20node%20[%20fontsize%20=%20%2216%22%20style%20=%20filled%20shape%20=%20%22record%22%20];%20%22matmul0%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20matmul%20|%20prev_func:%20None%22%20];%20%22matmul1%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20matmul%20|%20prev_func:%20matmul%22%20];%20%22softmax%22%20[%20label%20=%20%22%3Cf0%3E%20type:%20softmax%20|%20prev_func:%20matmul%22%20];%20}%20subgraph%20variable%20{%20node%20[%20fontsize%20=%20%2216%22%20shape%20=%20%22Mrecord%22%20style%20=%20filled%20fillcolor%20=%20white%20];%20%22x%22%20[%20label%20=%20%22%3Cf0%3E%20x%20|%20%3Cf1%3E%20creator:%20None%22%20];%20%22label%22%20[%20label%20=%20%22%3Cf0%3E%20label%20|%20%3Cf1%3E%20creator:%20None%22%20];%20%22W_1%22%20[%20label%20=%20%22%3Cf0%3E%20W_1%20|%20%3Cf1%3E%20creator:%20None%22%20];%20%22W_2%22%20[%20label%20=%20%22%3Cf0%3E%20W_2%20|%20%3Cf1%3E%20creator:%20None%22%20];%20%22h%22%20[%20label%20=%20%22%3Cf0%3E%20h%20|%20%3Cf1%3E%20creator:%20None%22%20];%20%22pred%22%20[%20label%20=%20%22%3Cf0%3E%20pred%20|%20%3Cf1%3E%20creator:%20matmul%22%20];%20%22loss%22%20[%20label%20=%20%22%3Cf0%3E%20loss%20|%20%3Cf1%3E%20creator:%20softmax%22%20];%20}%20subgraph%20data_flow%20{%20%22x%22:f0%20-%3E%20%22matmul0%22:f0;%20%22W_1%22:f0%20-%3E%20%22matmul0%22:f0;%20%22matmul0%22:f0%20-%3E%20%22h%22:f0;%20%22h%22:f0%20-%3E%20%22matmul1%22:f0;%20%22W_2%22:f0%20-%3E%20%22matmul1%22:f0;%20%22matmul1%22:f0%20-%3E%20%22pred%22:f0;%20%22pred%22:f0%20-%3E%20%22softmax%22:f0;%20%22label%22:f0%20-%3E%20%22softmax%22:f0;%20%22softmax%22:f0%20-%3E%20%22loss%22:f0;%20}%20subgraph%20prev_func%20{%20edge%20[color=%22red%22,%20arrowsize=%220.6%22,%20penwidth=%221%22,%20constraint=false];%20%22matmul1%22:f1%20-%3E%20%22matmul0%22:f0;%20%22softmax%22:f1%20-%3E%20%22matmul1%22:f0;%20label%20=%20%22prev_func%22;%20}%20})
Chainer and Autograd uses the similar techniques to record the forward pass. For details please refer to the appendix.
## Design choices
### 1) Dynet's List vs Pytorch's Node Graph
What's good about List:
1. It avoids a topological sort. One only needs to traverse the list of operators in reverse and calling the corresponding backward operator.
1. It promises effient data parallelism implementations. One could count the time of usage of a certain variable during the construction list. Then in the play back, one knows the calculation of a variable has completed. This enables communication and computation overlapping.
What's good about Node Graph:
1. More flexibility. PyTorch users can mix and match independent graphs however they like, in whatever threads they like (without explicit synchronization). An added benefit of structuring graphs this way is that when a portion of the graph becomes dead, it is automatically freed. [[2]](https://openreview.net/pdf?id=BJJsrmfCZ) Consider the following example, Pytorch only does backward on SmallNet while Dynet does both BigNet and SmallNet.
```python
result = BigNet(data)
loss = SmallNet(data)
loss.backward()
```
### 2) Dynet's Lazy evaluation vs Pytorch's Immediate evaluation
Dynet builds the list in a symbolic matter. Consider the following example
```python
for epoch in range(num_epochs):
for in_words, out_label in training_data:
dy.renew_cg()
W = dy.parameter(W_p)
b = dy.parameter(b_p)
score_sym = dy.softmax(W*dy.concatenate([E[in_words[0]],E[in_words[1]]])+b)
loss_sym = dy.pickneglogsoftmax(score_sym, out_label)
loss_val = loss_sym.value()
loss_sym.backward()
```
The computation of `lookup`, `concat`, `matmul` and `softmax` didn't happen until the call of `loss_sym.value()`. This defered execution is useful because it allows some graph-like optimization possible, e.g. kernel fusion.
Pytorch chooses immediate evaluation. It avoids ever materializing a "forward graph"/"tape" (no need to explicitly call `dy.renew_cg()` to reset the list), recording only what is necessary to differentiate the computation, i.e. `creator` and `prev_func`.
## What can fluid learn from them?
TBD
# Appendix
### Overview
| Framework | Has Tape | Core in C++ | First Release Date |
|-----------|----------|-------------|--------------------|
| Autograd | No | No | Mar 5, 2015 |
| Chainer | No | No | Jun 5, 2015 |
| Pytorch | No | Yes | Aug 31, 2016 |
| Dynet | Yes | Yes | Oct 12, 2016 |
### Source Code
#### Autograd
[Backward code](https://github.com/HIPS/autograd/blob/442205dfefe407beffb33550846434baa90c4de7/autograd/core.py#L8-L40). In the forward pass, a graph of VJPNode is constructed.
```python
# User API
def make_grad(fun, x):
start_node = VJPNode.new_root()
end_value, end_node = trace(start_node, fun, x)
return backward_pass(g, end_node), end_value
# trace the forward pass by creating VJPNodes
def trace(start_node, fun, x):
with trace_stack.new_trace() as t:
start_box = new_box(x, t, start_node)
end_box = fun(start_box)
return end_box._value, end_box._node
def backward_pass(g, end_node):
outgrads = {end_node : (g, False)}
for node in toposort(end_node):
outgrad = outgrads.pop(node)
ingrads = node.vjp(outgrad[0])
for parent, ingrad in zip(node.parents, ingrads):
outgrads[parent] = add_outgrads(outgrads.get(parent), ingrad)
return outgrad[0]
# Every VJPNode corresponds to a op_grad
class VJPNode(Node):
__slots__ = ['parents', 'vjp']
def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
self.parents = parents
vjpmaker = primitive_vjps[fun]
self.vjp = vjpmaker(parent_argnums, value, args, kwargs)
```
#### Chainer
Example Code
```python
# (1) Function Set definition, creates FunctionNode
model = FunctionSet(
l1=F.Linear(784, 100),
l2=F.Linear(100, 100),
l3=F.Linear(100, 10)).to_gpu()
# (2) Optimizer Setup
opt = optimizers.SGD()
opt.setup(model)
# (3) Forward computation
def forward(x, t):
h1 = F.relu(model.l1(x))
h2 = F.relu(model.l2(h1))
y = model.l3(h2)
return F.softmax_cross_entropy(y, t)
# (4) Training loop
for epoch in xrange(n_epoch):
for i in xrange(0, N, b_size):
x = Variable(to_gpu(...))
t = Variable(to_gpu(...))
opt.zero_grads()
loss = forward(x, t)
loss.backward()
opt.update()
```
In `forward(x, t)`, a graph of [`VariableNode`](https://github.com/chainer/chainer/blob/master/chainer/variable.py#L110) and [`FunctionNode`](https://github.com/chainer/chainer/blob/a69103a4aa59d5b318f39b01dbcb858d465b89cf/chainer/function_node.py#L19) is constructed. Every output's `VariableNode.creator` is pointed to the `FunctionNode`.
```python
class FunctionNode(object):
...
def apply(self, inputs):
outputs = self.forward(inputs)
ret = tuple([variable.Variable(y, requires_grad=requires_grad)
for y in outputs])
# Topological ordering
self.rank = max([x.rank for x in inputs]) if input_vars else 0
# Add backward edges
for y in ret:
y.creator_node = self
self.inputs = tuple([x.node for x in input_vars])
self.outputs = tuple([y.node for y in ret])
return ret
```
`loss.backward()` will calculate the accumulated gradient of all variables. All the backward of `FunctionNode`s will be called based on the topological order.
```python
class VariableNode(object):
...
def backward(self, retain_grad, loss_scale):
if self.creator_node is None:
return
cand_funcs = []
seen_set = set()
grads = {}
# Initialize error by 1, if this is a loss variable
if self.data.size == 1 and self._grad_var is None:
self.grad = numpy.ones_like(self.data)
grads[self._node] = self._grad_var
def add_cand(cand):
if cand not in seen_set:
# Negate since heapq is min-heap. This is a global variable
heapq.heappush(cand_funcs, (-cand.rank, len(seen_set), cand))
seen_set.add(cand)
add_cand(self.creator_node)
while cand_funcs:
_, _, func = heapq.heappop(cand_funcs)
gxs = func.backward_accumulate(func.inputs, func.outputs, func.outputs.grad)
for x, gx in enumerate(gxs):
if x in grads:
grads[x] += gx
else:
grads[x] = gx
if x.creator_node is not None:
add_cand(x.creator_node)
```
#### PyTorch
Example Code
```python
x = Variable(torch.ones(5, 5))
y = Variable(torch.ones(5, 5) * 4)
z = x ** 2 + x * 2 + x * y + y
z.backward(torch.ones(5, 5))
```
The trace is done by `Variable.creator` and `Function.previous_functions`.
```python
class Variable(object):
def __init__(self, tensor, creator=None, requires_grad=True):
if creator is None:
creator = Leaf(self, requires_grad)
self.data = tensor
self.creator = creator
self._grad = None
def backward(self, gradient=None):
if gradient is None:
if self.data.numel() != 1:
raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
gradient = self.data.new(1).fill_(1)
self._execution_engine.run_backward(self, gradient)
class Function(obejct):
# ...
def _do_forward(self, *input):
unpacked_input = tuple(arg.data for arg in input)
raw_output = self.forward(*unpacked_input)
# mark output.creator = self for backward trace
output = tuple(Variable(tensor, self) for tensor in raw_output)
self.previous_functions = [(arg.creator, id(arg)) for arg in input]
self.output_ids = {id(var): i for i, var in enumerate(output)}
return output
def _do_backward(self, grad_output):
return self.backwaerd(grad_output)
```
The [backward](https://github.com/pytorch/pytorch/blob/v0.1.1/torch/autograd/engine.py) is similar to Autograd.
#### DyNet
Example code
```python
model = dy.model()
W_p = model.add_parameters((20, 100))
b_p = model.add_parameters(20)
E = model.add_lookup_parameters((20000, 50))
for epoch in range(num_epochs):
for in_words, out_label in training_data:
dy.renew_cg() # init tape
W = dy.parameter(W_p)
b = dy.parameter(b_p)
score_sym = dy.softmax(W*dy.concatenate([E[in_words[0]],E[in_words[1]]])+b)
loss_sym = dy.pickneglogsoftmax(score_sym, out_label)
loss_val = loss_sym.value()
loss_sym.backward()
```
[forward](https://github.com/clab/dynet/blob/740a9626a13a2732544de142e256ad0d0a166658/dynet/exec.cc#L84-L158), [backward](https://github.com/clab/dynet/blob/740a9626a13a2732544de142e256ad0d0a166658/dynet/exec.cc#L166-L284). The trace is done by creating a tape of expressions in every iteration. Backward is done by traverse the tape in the reverse order.
```c++
void SimpleExecutionEngine::backward(VariableIndex from_where, bool full) {
...
for (int i = num_nodes - 1; i >= 0; --i) {
// each node corresponds to an op
node->backward(xs, node_fx, node_dEdfx, ai, node_dEdxai);
}
...
}
```
...@@ -109,7 +109,6 @@ void MainWord2Vec(bool use_gpu) { ...@@ -109,7 +109,6 @@ void MainWord2Vec(bool use_gpu) {
void MainImageClassification(bool use_gpu) { void MainImageClassification(bool use_gpu) {
int batch_size = 2; int batch_size = 2;
bool use_mkldnn = false;
bool repeat = false; bool repeat = false;
NativeConfig config = GetConfig(); NativeConfig config = GetConfig();
config.use_gpu = use_gpu; config.use_gpu = use_gpu;
...@@ -134,12 +133,8 @@ void MainImageClassification(bool use_gpu) { ...@@ -134,12 +133,8 @@ void MainImageClassification(bool use_gpu) {
std::vector<framework::LoDTensor*> cpu_fetchs1; std::vector<framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1); cpu_fetchs1.push_back(&output1);
TestInference<platform::CPUPlace, false, true>(config.model_dir, TestInference<platform::CPUPlace, false, true>(
cpu_feeds, config.model_dir, cpu_feeds, cpu_fetchs1, repeat, is_combined);
cpu_fetchs1,
repeat,
is_combined,
use_mkldnn);
auto predictor = CreatePaddlePredictor(config); auto predictor = CreatePaddlePredictor(config);
std::vector<PaddleTensor> paddle_tensor_feeds; std::vector<PaddleTensor> paddle_tensor_feeds;
......
...@@ -83,11 +83,16 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) ...@@ -83,11 +83,16 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope if(WITH_DISTRIBUTE)
framework_proto glog lod_rank_table feed_fetch_method) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -8,18 +8,19 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place ...@@ -8,18 +8,19 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_GPU) if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
else() else()
set(multi_devices_graph_builder_deps) cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
endif() endif()
...@@ -28,10 +29,10 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d ...@@ -28,10 +29,10 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle)
cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
...@@ -13,25 +13,33 @@ ...@@ -13,25 +13,33 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
const std::vector<Scope *> &local_scopes, #ifdef PADDLE_WITH_CUDA
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs) const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) { for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p); this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
}
} }
} }
#else
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
#endif
void NCCLAllReduceOpHandle::RunImpl() { void AllReduceOpHandle::RunImpl() {
if (NoDummyInputSize() == 1) { if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
...@@ -58,6 +66,8 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -58,6 +66,8 @@ void NCCLAllReduceOpHandle::RunImpl() {
} }
if (platform::is_gpu_place(lod_tensors[0]->place())) { if (platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
...@@ -75,7 +85,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -75,7 +85,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
} }
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_.at(dev_id); auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream(); auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_; auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] { all_reduce_calls.emplace_back([=] {
...@@ -90,22 +100,25 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -90,22 +100,25 @@ void NCCLAllReduceOpHandle::RunImpl() {
call(); call();
} }
}); });
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} else { // Special handle CPU only Operator's gradient. Like CRF } else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *this->local_scopes_[0] auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName) ->FindVar(kLocalExecScopeName)
->Get<Scope *>() ->Get<Scope *>()
->Var() ->FindVar(out_var_handles[0]->name_)
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
// Reduce All Tensor to trg in CPU // Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg); ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(ToDataType(lod_tensors[0]->type()), func);
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope = auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>(); *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i]; auto &p = places_[i];
auto *var = scope.FindVar(in_var_handles[i]->name_); auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p]; auto *dev_ctx = dev_ctxes_[p];
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
...@@ -118,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -118,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
} }
} }
std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; } std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,17 +20,23 @@ ...@@ -20,17 +20,23 @@
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct NCCLAllReduceOpHandle : public OpHandleBase { struct AllReduceOpHandle : public OpHandleBase {
NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes, #ifdef PADDLE_WITH_CUDA
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs); const platform::NCCLContextMap *ctxs);
#else
AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override; std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase // Delay and buffer nccl_all_reduce together can significantly increase
...@@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
private: private:
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
const platform::NCCLContextMap &nccl_ctxs_; #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
#endif
}; };
} // namespace details } // namespace details
......
...@@ -20,7 +20,7 @@ namespace details { ...@@ -20,7 +20,7 @@ namespace details {
struct ExecutionStrategy { struct ExecutionStrategy {
size_t num_threads_{0}; size_t num_threads_{0};
bool use_event_{true}; bool use_cuda_{true};
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100}; size_t num_iteration_per_drop_scope_{100};
}; };
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
...@@ -26,10 +27,6 @@ ...@@ -26,10 +27,6 @@
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -89,7 +86,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( ...@@ -89,7 +86,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find send op, // TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string // instead of the the hard code string
if (op->Type() == "send_vars") { if (op->Type() == "send") {
auto op_vars = op->InputArgumentNames(); auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() + send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end())); std::distance(op_vars.begin(), op_vars.end()));
...@@ -243,7 +240,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -243,7 +240,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
InsertNCCLAllReduceOp(&result, g_name); InsertAllReduceOp(&result, g_name);
} }
break; break;
} }
...@@ -286,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient( ...@@ -286,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
return false; return false;
} }
void MultiDevSSAGraphBuilder::SetCommunicationContext(
OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
}
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
...@@ -300,15 +310,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, ...@@ -300,15 +310,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
op_handle->AddInput(in); op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_.at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p); auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var); vars.emplace_back(out_var);
op_handle->AddOutput(out_var); op_handle->AddOutput(out_var);
#ifndef ADDLE_WITH_CUDA
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
} }
} }
...@@ -320,15 +327,19 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, ...@@ -320,15 +327,19 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id); CreateOpHandleIOs(result, op, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
SSAGraph *result, const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back( result->ops_.emplace_back(
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og]; auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
...@@ -338,9 +349,6 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( ...@@ -338,9 +349,6 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
#else
PADDLE_ENFORCE("Not implemented");
#endif
} }
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
...@@ -379,7 +387,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { ...@@ -379,7 +387,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]); auto *communication_dev_ctx =
nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
: platform::DeviceContextPool::Instance().Get(places_[i]);
#else #else
auto *communication_dev_ctx = auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
...@@ -424,12 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -424,12 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_[i][og];
#ifndef PADDLE_WITH_CUDA
auto &p = places_[i]; auto &p = places_[i];
op_handle->SetDeviceContext(p, SetCommunicationContext(op_handle, p);
platform::DeviceContextPool::Instance().Get(p)); auto &vars = result->vars_[i][og];
#endif
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
...@@ -468,17 +475,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, ...@@ -468,17 +475,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0])); new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
if (op.Type() == "send_barrier") { if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars"); ConnectOp(result, result->ops_.back().get(), "send");
} else if (op.Type() == "recv") { } else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier"); ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") { } else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv"); ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") { } else if (op.Type() == "send") {
// do nothing // do nothing
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"rpc op should be in [" "rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]"); "send, send_barrier. recv, fetch_barrier]");
} }
// TODO(Yancey1989): schedule rpc op on different place may // TODO(Yancey1989): schedule rpc op on different place may
......
...@@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices, const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const; const OpDesc &op) const;
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
...@@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -39,9 +39,9 @@ OpHandleBase::~OpHandleBase() { ...@@ -39,9 +39,9 @@ OpHandleBase::~OpHandleBase() {
#endif #endif
} }
void OpHandleBase::Run(bool use_event) { void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_event) { if (events_.empty() && use_cuda) {
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id)); PADDLE_ENFORCE(cudaSetDevice(dev_id));
...@@ -50,7 +50,7 @@ void OpHandleBase::Run(bool use_event) { ...@@ -50,7 +50,7 @@ void OpHandleBase::Run(bool use_event) {
} }
} }
#else #else
PADDLE_ENFORCE(!use_event); PADDLE_ENFORCE(!use_cuda);
#endif #endif
RunImpl(); RunImpl();
......
...@@ -36,7 +36,7 @@ class OpHandleBase { ...@@ -36,7 +36,7 @@ class OpHandleBase {
virtual std::string Name() const = 0; virtual std::string Name() const = 0;
void Run(bool use_event); void Run(bool use_cuda);
virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx); virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
......
...@@ -37,7 +37,9 @@ struct ReduceLoDTensor { ...@@ -37,7 +37,9 @@ struct ReduceLoDTensor {
PADDLE_ENFORCE_NE(t0.numel(), 0); PADDLE_ENFORCE_NE(t0.numel(), 0);
dst_tensor_.Resize(t0.dims()); dst_tensor_.Resize(t0.dims());
T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace()); T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
if (dst != t0.data<T>()) {
std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst); std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
}
for (size_t i = 1; i < src_tensors_.size(); ++i) { for (size_t i = 1; i < src_tensors_.size(); ++i) {
auto &t = *src_tensors_[i]; auto &t = *src_tensors_[i];
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <utility>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h" #include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream> #include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace paddle { namespace paddle {
...@@ -40,6 +41,8 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() { ...@@ -40,6 +41,8 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
res.reset(new SSAGraghBuilderWithPrinter( res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res))); std::move(fout), std::move(graphviz_printer), std::move(res)));
} }
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
return res; return res;
} }
} // namespace details } // namespace details
......
...@@ -40,7 +40,11 @@ class SSAGraphBuilderFactory { ...@@ -40,7 +40,11 @@ class SSAGraphBuilderFactory {
loss_var_name_(loss_var_name), loss_var_name_(loss_var_name),
param_names_(param_names), param_names_(param_names),
local_scopes_(local_scopes), local_scopes_(local_scopes),
strategy_(strategy) {} strategy_(strategy) {
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_ = nullptr;
#endif
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->generated_op_ == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
}
}
}
for (auto &var : graph->dep_vars_) {
insert_pending_var(var.get());
}
for (auto &op : graph->ops_) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
}
}
set.clear();
};
while (!pending_vars.empty()) {
run_all_ops(ready_ops);
if (ready_vars.empty()) {
return false;
}
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
}
}
}
ready_vars.clear();
}
return true;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
explicit SSAGraghBuilderWithChecker(
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
}
bool IsValidGraph(const SSAGraph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ...@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
ready_vars->Push(var); ready_vars->Push(var);
} }
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
...@@ -192,7 +193,7 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -192,7 +193,7 @@ void ThreadedSSAGraphExecutor::RunOp(
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
} }
op->Run(strategy_.use_event_); op->Run(strategy_.use_cuda_);
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
......
...@@ -20,10 +20,14 @@ limitations under the License. */ ...@@ -20,10 +20,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h"
#endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -43,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { ...@@ -43,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>()
->SendComplete();
}
#endif
void InitializeVariable(Variable* var, proto::VarType::Type var_type) { void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
...@@ -115,6 +127,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -115,6 +127,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
auto ctx = Prepare(pdesc, block_id); auto ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
} }
...@@ -214,6 +227,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -214,6 +227,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& feed_holder_name, const std::string& feed_holder_name,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
if (FLAGS_use_mkldnn) EnableMKLDNN(program);
bool has_feed_ops = bool has_feed_ops =
has_feed_operators(program.Block(0), *feed_targets, feed_holder_name); has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
bool has_fetch_ops = bool has_fetch_ops =
...@@ -225,7 +239,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -225,7 +239,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
unique_ptr_of_copy_program.reset(new ProgramDesc(program)); unique_ptr_of_copy_program.reset(new ProgramDesc(program));
copy_program = unique_ptr_of_copy_program.get(); copy_program = unique_ptr_of_copy_program.get();
} }
auto* global_block = copy_program->MutableBlock(0); auto* global_block = copy_program->MutableBlock(0);
if (!has_feed_ops) { if (!has_feed_ops) {
...@@ -378,5 +391,19 @@ void Executor::RunPreparedContext( ...@@ -378,5 +391,19 @@ void Executor::RunPreparedContext(
} }
} }
void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True";
for (size_t bid = 0; bid < program.Size(); ++bid) {
auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
for (auto* op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) {
op->SetAttr("use_mkldnn", true);
}
}
}
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -44,6 +44,13 @@ class Executor { ...@@ -44,6 +44,13 @@ class Executor {
explicit Executor(const platform::Place& place); explicit Executor(const platform::Place& place);
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current trainer stop.
*/
void Complete();
#endif
/* @Brief /* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope * Runtime evaluation of the given ProgramDesc under certain Scope
* *
...@@ -81,6 +88,8 @@ class Executor { ...@@ -81,6 +88,8 @@ class Executor {
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
void EnableMKLDNN(const ProgramDesc& program);
private: private:
const platform::Place place_; const platform::Place place_;
}; };
......
...@@ -71,6 +71,7 @@ message OpProto { ...@@ -71,6 +71,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ]; optional bool dispensable = 5 [ default = false ];
optional string reuse = 6;
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -17,12 +17,11 @@ limitations under the License. */ ...@@ -17,12 +17,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static OpInfoMap* g_op_info_map = nullptr; // C++11 removes the need for manual locking. Concurrent execution shall wait if
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
OpInfoMap& OpInfoMap::Instance() { OpInfoMap& OpInfoMap::Instance() {
if (g_op_info_map == nullptr) { static OpInfoMap* g_op_info_map = new OpInfoMap();
g_op_info_map = new OpInfoMap();
}
return *g_op_info_map; return *g_op_info_map;
} }
} // namespace framework } // namespace framework
......
...@@ -21,6 +21,7 @@ namespace framework { ...@@ -21,6 +21,7 @@ namespace framework {
void OpProtoAndCheckerMaker::Validate() { void OpProtoAndCheckerMaker::Validate() {
validated_ = true; validated_ = true;
CheckNoDuplicatedInOutAttrs(); CheckNoDuplicatedInOutAttrs();
CheckReuseVars();
} }
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
...@@ -56,6 +57,24 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { ...@@ -56,6 +57,24 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
} }
} }
void OpProtoAndCheckerMaker::CheckReuseVars() {
std::unordered_set<std::string> names;
for (auto& input : proto_->inputs()) {
names.insert(input.name());
}
auto checker = [&](const std::string& name, const std::string& reused) {
PADDLE_ENFORCE(
names.count(reused),
"Output [%s] reuse Input [%s], but the input is not registered.", name,
reused);
};
for (auto& output : proto_->outputs()) {
if (output.has_reuse()) {
checker(output.name(), output.reuse());
}
}
}
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) { OpAttrChecker* attr_checker) {
proto_ = proto; proto_ = proto;
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -64,6 +66,11 @@ class OpProtoAndCheckerMaker { ...@@ -64,6 +66,11 @@ class OpProtoAndCheckerMaker {
var_->set_dispensable(true); var_->set_dispensable(true);
return *this; return *this;
} }
VariableBuilder &Reuse(const std::string &name) {
var_->set_reuse(name);
return *this;
}
}; };
VariableBuilder AddInput(const std::string &name, const std::string &comment); VariableBuilder AddInput(const std::string &name, const std::string &comment);
...@@ -89,6 +96,8 @@ class OpProtoAndCheckerMaker { ...@@ -89,6 +96,8 @@ class OpProtoAndCheckerMaker {
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
void Validate(); void Validate();
void CheckReuseVars();
proto::OpProto *proto_; proto::OpProto *proto_;
OpAttrChecker *op_checker_; OpAttrChecker *op_checker_;
bool validated_{false}; bool validated_{false};
......
...@@ -47,3 +47,23 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -47,3 +47,23 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW(proto_maker(&op_proto, &op_checker), ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet); paddle::platform::EnforceNotMet);
} }
class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
AddOutput("NoOut", "output of test op").Reuse("NotExists");
}
};
TEST(ProtoMaker, InplaceOutput) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
TestInplaceProtoMaker proto_maker;
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet);
// proto_maker(&op_proto, &op_checker);
// proto_maker.Make();
// ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
...@@ -22,8 +22,8 @@ limitations under the License. */ ...@@ -22,8 +22,8 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -43,7 +43,8 @@ class ParallelExecutorPrivate { ...@@ -43,7 +43,8 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif #endif
bool own_local_scope; bool own_local_scope_;
bool use_cuda_;
}; };
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
...@@ -60,23 +61,25 @@ ParallelExecutor::ParallelExecutor( ...@@ -60,23 +61,25 @@ ParallelExecutor::ParallelExecutor(
size_t num_trainers, size_t trainer_id) size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_;
// Step 1. Bcast the params to devs. // Step 1. Bcast the params to devs.
// Create local scopes // Create local scopes
if (local_scopes.empty()) { if (local_scopes.empty()) {
member_->own_local_scope = true; member_->own_local_scope_ = true;
member_->local_scopes_.emplace_back(member_->global_scope_); member_->local_scopes_.emplace_back(member_->global_scope_);
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(&scope->NewScope()); member_->local_scopes_.emplace_back(&scope->NewScope());
} }
} else { } else {
member_->own_local_scope = false; member_->own_local_scope_ = false;
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size()); PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope()); member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope());
} }
} }
if (member_->use_cuda_) {
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
...@@ -86,9 +89,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -86,9 +89,12 @@ ParallelExecutor::ParallelExecutor(
} }
member_->nccl_ctxs_.reset(new platform::NCCLContextMap( member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id)); member_->places_, nccl_id, num_trainers, trainer_id));
#else
PADDLE_THROW("Not compiled with CUDA");
#endif #endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 && }
local_scopes.empty()) { // Is CUDA
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToGPUs(bcast_vars); BCastParamsToGPUs(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
...@@ -108,9 +114,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -108,9 +114,13 @@ ParallelExecutor::ParallelExecutor(
details::SSAGraphBuilderFactory builder_factory( details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy); build_strategy);
if (member_->use_cuda_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
#else
PADDLE_THROW("Not compiled with CUDA");
#endif #endif
}
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places,
...@@ -123,7 +133,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -123,7 +133,6 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const { const std::unordered_set<std::string> &vars) const {
#ifdef PADDLE_WITH_CUDA
auto *main_scope = member_->local_scopes_[0]; auto *main_scope = member_->local_scopes_[0];
for (auto &var : vars) { for (auto &var : vars) {
...@@ -135,6 +144,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -135,6 +144,7 @@ void ParallelExecutor::BCastParamsToGPUs(
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
#ifdef PADDLE_WITH_CUDA
size_t numel = main_tensor.numel(); size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
...@@ -153,6 +163,10 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -153,6 +163,10 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::dynload::ncclBcast(buffer, numel, data_type, 0, platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream()); nccl_ctx.comm_, nccl_ctx.stream());
} }
member_->nccl_ctxs_->WaitAll();
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} else { } else {
platform::CPUPlace cpu; platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
...@@ -163,11 +177,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -163,11 +177,7 @@ void ParallelExecutor::BCastParamsToGPUs(
paddle::framework::TensorCopy(main_tensor, cpu, t); paddle::framework::TensorCopy(main_tensor, cpu, t);
} }
} }
member_->nccl_ctxs_->WaitAll();
} }
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
} }
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
...@@ -213,7 +223,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -213,7 +223,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
if (member_->own_local_scope) { if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
member_->global_scope_->DeleteScope(member_->local_scopes_[i]); member_->global_scope_->DeleteScope(member_->local_scopes_[i]);
} }
......
...@@ -35,14 +35,15 @@ class ReaderBase { ...@@ -35,14 +35,15 @@ class ReaderBase {
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
public: public:
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
protected: protected:
ReaderBase* reader_; std::shared_ptr<ReaderBase> reader_;
}; };
class FileReader : public ReaderBase { class FileReader : public ReaderBase {
...@@ -64,7 +65,7 @@ class ReaderHolder { ...@@ -64,7 +65,7 @@ class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); } std::shared_ptr<ReaderBase> Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
...@@ -76,7 +77,7 @@ class ReaderHolder { ...@@ -76,7 +77,7 @@ class ReaderHolder {
} }
private: private:
std::unique_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
} // namespace framework } // namespace framework
......
...@@ -64,7 +64,8 @@ class TRTConvertValidation { ...@@ -64,7 +64,8 @@ class TRTConvertValidation {
TRTConvertValidation(int batch_size, TRTConvertValidation(int batch_size,
const std::unordered_set<std::string>& parameters, const std::unordered_set<std::string>& parameters,
framework::Scope& scope, int workspace_size = 1 << 10) framework::Scope& scope, // NOLINT
int workspace_size = 1 << 10)
: parameters_(parameters), scope_(scope) { : parameters_(parameters), scope_(scope) {
// create engine. // create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_)); engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
......
...@@ -21,7 +21,6 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model."); ...@@ -21,7 +21,6 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data"); DEFINE_int32(batch_size, 1, "Batch size of input data");
DEFINE_int32(repeat, 1, "Running the inference program repeat times"); DEFINE_int32(repeat, 1, "Running the inference program repeat times");
DEFINE_bool(skip_cpu, false, "Skip the cpu test"); DEFINE_bool(skip_cpu, false, "Skip the cpu test");
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
TEST(inference, image_classification) { TEST(inference, image_classification) {
if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) {
...@@ -59,10 +58,8 @@ TEST(inference, image_classification) { ...@@ -59,10 +58,8 @@ TEST(inference, image_classification) {
// Run inference on CPU // Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size; LOG(INFO) << "Batch size is " << FLAGS_batch_size;
LOG(INFO) << "FLAGS_use_mkldnn: " << FLAGS_use_mkldnn;
TestInference<paddle::platform::CPUPlace, false, true>( TestInference<paddle::platform::CPUPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined, dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
FLAGS_use_mkldnn);
LOG(INFO) << output1.dims(); LOG(INFO) << output1.dims();
} }
......
...@@ -27,7 +27,6 @@ limitations under the License. */ ...@@ -27,7 +27,6 @@ limitations under the License. */
DEFINE_string(model_path, "", "Directory of the inference model."); DEFINE_string(model_path, "", "Directory of the inference model.");
DEFINE_string(data_file, "", "File of input index data."); DEFINE_string(data_file, "", "File of input index data.");
DEFINE_int32(repeat, 100, "Running the inference program repeat times"); DEFINE_int32(repeat, 100, "Running the inference program repeat times");
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
DEFINE_bool(prepare_vars, true, "Prepare variables before executor"); DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
DEFINE_int32(num_threads, 1, "Number of threads should be used"); DEFINE_int32(num_threads, 1, "Number of threads should be used");
...@@ -190,9 +189,6 @@ TEST(inference, nlp) { ...@@ -190,9 +189,6 @@ TEST(inference, nlp) {
std::unique_ptr<paddle::framework::ProgramDesc> inference_program; std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path, inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path,
/*model combined*/ false); /*model combined*/ false);
if (FLAGS_use_mkldnn) {
EnableMKLDNN(inference_program);
}
// always prepare context // always prepare context
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx; std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
ctx = executor.Prepare(*inference_program, 0); ctx = executor.Prepare(*inference_program, 0);
......
...@@ -22,6 +22,8 @@ limitations under the License. */ ...@@ -22,6 +22,8 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(use_mkldnn);
template <typename T> template <typename T>
void SetupTensor(paddle::framework::LoDTensor* input, void SetupTensor(paddle::framework::LoDTensor* input,
paddle::framework::DDim dims, T lower, T upper) { paddle::framework::DDim dims, T lower, T upper) {
...@@ -133,24 +135,11 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes( ...@@ -133,24 +135,11 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
return feed_target_shapes; return feed_target_shapes;
} }
void EnableMKLDNN(
const std::unique_ptr<paddle::framework::ProgramDesc>& program) {
for (size_t bid = 0; bid < program->Size(); ++bid) {
auto* block = program->MutableBlock(bid);
for (auto* op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) {
op->SetAttr("use_mkldnn", true);
}
}
}
}
template <typename Place, bool CreateVars = true, bool PrepareContext = false> template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname, void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
const int repeat = 1, const bool is_combined = false, const int repeat = 1, const bool is_combined = false) {
const bool use_mkldnn = false) {
// 1. Define place, executor, scope // 1. Define place, executor, scope
auto place = Place(); auto place = Place();
auto executor = paddle::framework::Executor(place); auto executor = paddle::framework::Executor(place);
...@@ -182,9 +171,6 @@ void TestInference(const std::string& dirname, ...@@ -182,9 +171,6 @@ void TestInference(const std::string& dirname,
"init_program", "init_program",
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
inference_program = InitProgram(&executor, scope, dirname, is_combined); inference_program = InitProgram(&executor, scope, dirname, is_combined);
if (use_mkldnn) {
EnableMKLDNN(inference_program);
}
} }
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
...@@ -210,7 +196,10 @@ void TestInference(const std::string& dirname, ...@@ -210,7 +196,10 @@ void TestInference(const std::string& dirname,
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
} }
// 6. Run the inference program // 6. If export Flags_use_mkldnn=True, use mkldnn related ops.
if (FLAGS_use_mkldnn) executor.EnableMKLDNN(*inference_program);
// 7. Run the inference program
{ {
if (!CreateVars) { if (!CreateVars) {
// If users don't want to create and destroy variables every time they // If users don't want to create and destroy variables every time they
......
...@@ -187,18 +187,22 @@ endif() ...@@ -187,18 +187,22 @@ endif()
add_subdirectory(detail) add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib)
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS}) op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
...@@ -208,15 +212,18 @@ if(WITH_DISTRIBUTE) ...@@ -208,15 +212,18 @@ if(WITH_DISTRIBUTE)
# listen_and_serv_op sum_op executor SERIAL) # listen_and_serv_op sum_op executor SERIAL)
if(WITH_GPU) if(WITH_GPU)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL)
listen_and_serv_op executor SERIAL) if(WITH_GRPC)
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc) op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
else()
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_brpc)
endif()
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif() endif()
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
public: \ public: \
void Make() override { \ void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \ AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \ AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \ AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \ "(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \ .SetDefault(false); \
......
...@@ -89,9 +89,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -89,9 +89,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("ParamOut", "(Tensor) Output parameter").Reuse("Param");
AddOutput("Moment1Out", "(Tensor) Output first moment"); AddOutput("Moment1Out", "(Tensor) Output first moment").Reuse("Moment1");
AddOutput("Moment2Out", "(Tensor) Output second moment"); AddOutput("Moment2Out", "(Tensor) Output second moment").Reuse("Moment2");
AddAttr<float>("beta1", AddAttr<float>("beta1",
"(float, default 0.9) " "(float, default 0.9) "
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OP_CUDA_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
enum ArgMinMaxType { kArgMin, kArgMax };
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, int64_t axis) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
template <typename DeviceContext, typename T, typename Tout,
ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
out.mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, axis)
switch (x.dims().size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_THROW(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template <typename DeviceContext, typename T>
using ArgMinKernel =
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;
template <typename DeviceContext, typename T>
using ArgMaxKernel =
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
const auto& x_dims = ctx->GetInputDim("X");
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
"'axis' must be inside [-Rank(X), Rank(X))");
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
std::vector<int64_t> vec;
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
ctx->SetOutputDim("Out", framework::make_ddim(vec));
}
};
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
virtual const char* OpName() const = 0;
virtual const char* Name() const = 0;
public:
void Make() override {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddComment(string::Sprintf(R"DOC(
%s Operator.
Computes the indices of the %s elements of the input tensor's element
along the provided axis.
)DOC",
OpName(), Name()));
}
};
class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
protected:
const char* OpName() const override { return "ArgMin"; }
const char* Name() const override { return "min"; }
};
class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
protected:
const char* OpName() const override { return "ArgMax"; }
const char* Name() const override { return "max"; }
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OP_CUDA_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
...@@ -19,10 +19,17 @@ limitations under the License. */ ...@@ -19,10 +19,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using batch_norm_bwd = mkldnn::batch_normalization_backward;
using batch_norm_fwd = mkldnn::batch_normalization_forward;
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc; using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; using platform::to_void_cast;
template <typename T> template <typename T>
using EigenArrayMap = using EigenArrayMap =
...@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) { ...@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) {
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
template <typename T>
inline void *cast_const_to_void(const T *t) {
return static_cast<void *>(const_cast<T *>(t));
}
} // namespace } // namespace
template <typename T> template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias"); const auto *shift = ctx.Input<Tensor>("Bias");
y->mutable_data<T>(ctx.GetPlace()); PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
mean_out->mutable_data<T>(ctx.GetPlace()); x->format() != memory::format::format_undef,
variance_out->mutable_data<T>(ctx.GetPlace()); "Wrong layout/format set for Input x tensor");
const T *x_data = x->data<T>();
const T *mean_data = mean->data<T>();
const T *variance_data = variance->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace());
T *mean_out_data = mean_out->mutable_data<T>(ctx.GetPlace());
T *variance_out_data = variance_out->mutable_data<T>(ctx.GetPlace());
T *batch_mean_data = nullptr;
T *batch_variance_data = nullptr;
if (!is_test) { if (!is_test) {
batch_mean->mutable_data<T>(ctx.GetPlace()); batch_mean_data = batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance->mutable_data<T>(ctx.GetPlace()); batch_variance_data = batch_variance->mutable_data<T>(ctx.GetPlace());
} }
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
auto dims = paddle::framework::vectorize2int(x->dims()); auto src_tz = paddle::framework::vectorize2int(x->dims());
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
auto src_md = PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); const unsigned int ic = scale_tz[0];
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
auto dst = mkldnn::memory{dst_pd, y->data<T>()};
unsigned flags = mkldnn::use_scale_shift; unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats; if (is_test) flags |= mkldnn::use_global_stats;
// create mkldnn memory from input x tensor
auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc = auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
bn_fwd_types::op_desc{propagation, src_md, epsilon, flags}; propagation, src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_fwd_pd = std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine}; std::shared_ptr<batch_norm_fwd::primitive_desc>(
new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc,
mkldnn_engine));
const unsigned int ic = dims[1]; // Save the pd to be used in backward pass
const std::string key = ctx.op().Output("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
// MKLDNN requires a single piece of memory for scale and shift/bias data // MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
...@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(), copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data); shift->data<T>() + ic, &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{ // crate mkldnn memory for weights(scale/shift)
batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()}; auto scaleshift_memory = memory(batch_norm_fwd_pd->weights_primitive_desc(),
scaleshift_data.data());
if (is_test) { // create mkldnn memory for output y tensor
auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(), auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data);
cast_const_to_void(mean->data<T>())};
if (is_test) {
// create mkldnn memory for stats (as input)
auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(),
to_void_cast(mean_data));
auto variance_memory = auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(), memory(batch_norm_fwd_pd->variance_primitive_desc(),
cast_const_to_void(variance->data<T>())}; to_void_cast(variance_data));
run_batch_norm_op<typename bn_fwd_types::op_type>( run_batch_norm_op<typename bn_fwd_types::op_type>(
batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory, *batch_norm_fwd_pd, src_memory,
(const mkldnn::primitive::at &)mean_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory, (const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
dst); dst_memory);
} else { } else {
// create mkldnn memory for stats (as output)
auto mean_memory = auto mean_memory =
mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(), memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data);
cast_const_to_void(batch_mean->data<T>())}; auto variance_memory = memory(
batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data);
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src, run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory,
scaleshift_memory, dst, scaleshift_memory, dst_memory,
mean_memory, variance_memory); mean_memory, variance_memory);
} }
if (!is_test) { if (!is_test) {
const unsigned int in = dims[0]; // mkldnn only compute stats for current batch
const unsigned int sample_size = x->numel() / in / ic; // so we need compute momentum stats via Eigen lib
EigenVectorArrayMap<T> batch_mean_e(batch_mean_data, ic);
// saved_xx is use just in this batch of data EigenVectorArrayMap<T> batch_variance_e(batch_variance_data, ic);
EigenVectorArrayMap<T> saved_mean_e( ConstEigenVectorArrayMap<T> mean_e(mean_data, ic);
batch_mean->mutable_data<T>(ctx.GetPlace()), ic); ConstEigenVectorArrayMap<T> variance_e{variance_data, ic};
EigenVectorArrayMap<T> saved_variance_e(
batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
saved_mean_e.setZero();
saved_variance_e.setZero();
const unsigned int x_arr_size = in * ic;
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_mean_e(nc % ic) += x_arr.col(nc).sum();
}
saved_mean_e /= in * sample_size;
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_variance_e(nc % ic) +=
(x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
}
saved_variance_e /= in * sample_size;
ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
EigenVectorArrayMap<T> running_mean_arr( EigenVectorArrayMap<T> running_mean_e(mean_out_data, ic);
mean_out->mutable_data<T>(ctx.GetPlace()), ic); EigenVectorArrayMap<T> running_variance_e(variance_out_data, ic);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), ic);
auto one_minus_momentum = 1. - momentum; auto one_minus_momentum = 1. - momentum;
running_mean_arr = running_mean_e = mean_e * momentum + batch_mean_e * one_minus_momentum;
mean_arr * momentum + saved_mean_e * one_minus_momentum; running_variance_e =
running_var_arr = variance_e * momentum + batch_variance_e * one_minus_momentum;
variance_arr * momentum + saved_variance_e * one_minus_momentum;
} }
y->set_layout(DataLayout::kMKLDNN);
y->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
} }
}; };
...@@ -217,11 +212,6 @@ template <typename T> ...@@ -217,11 +212,6 @@ template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine(); auto mkldnn_engine = dev_ctx.GetEngine();
...@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
diff_x->mutable_data<T>(ctx.GetPlace()); PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
diff_scale->mutable_data<T>(ctx.GetPlace()); diff_y->format() != memory::format::format_undef,
diff_shift->mutable_data<T>(ctx.GetPlace()); "Wrong layout/format set for Input diff_y tensor");
const T *x_data = x->data<T>();
const T *diff_y_data = diff_y->data<T>();
const T *batch_mean_data = batch_mean->data<T>();
const T *batch_variance_data = batch_variance->data<T>();
const T *scale_data = scale->data<T>();
const T *shift_data = shift->data<T>();
T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize2int(x->dims());
auto diff_src_tz = src_tz;
auto dst_tz = src_tz;
auto diff_dst_tz = dst_tz;
auto scale_tz = paddle::framework::vectorize2int(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
const unsigned int ic = scale_tz[0];
// Retrieve bn_fwd_pd from device context
const std::string key = ctx.op().Input("SavedMean");
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto batch_norm_fwd_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx.GetBlob(key_batch_norm_fwd_pd));
PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
"Fail to find batch_norm_fwd_pd in device context");
auto dims = paddle::framework::vectorize2int(x->dims()); using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
auto src_md = // create mkldnn memory from input diff_y tensor
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); auto user_diff_dst_memory =
auto dst_md = memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw); mkldnn_engine},
auto diff_src_md = to_void_cast(diff_y_data));
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; // create mkldnn memory from input x tensor
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>; auto src_memory =
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
to_void_cast(x_data));
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{ // for diff_dst, try to use same format as dst in forward pass
mkldnn::prop_kind::forward_training, src_md, epsilon, flags}; auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
auto batch_norm_fwd_pd = auto diff_dst_md = diff_dst_pd.desc();
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
// create primitive descriptor for batch norm backward
unsigned flags = mkldnn::use_scale_shift;
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{ auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags}; mkldnn::prop_kind::backward, diff_dst_md,
src_memory.get_primitive_desc().desc(), epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{ auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd}; batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
auto src = mkldnn::memory{{src_md, mkldnn_engine}, // reorder user_diff_dst if it's not in preferred format
cast_const_to_void(x->data<T>())}; auto diff_dst_memory = user_diff_dst_memory;
primitive reorder_diff_dst;
auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(), bool is_diff_dst_reordered = false;
cast_const_to_void(batch_mean->data<T>())}; if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory = memory(diff_dst_pd);
auto variance = reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory);
mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(), is_diff_dst_reordered = true;
cast_const_to_void(batch_variance->data<T>())}; }
auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
cast_const_to_void(diff_y->data<T>())};
const unsigned int ic = dims[1]; // create mkldnn memory for input tensors (src/mean/variance)
auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(),
to_void_cast(batch_mean_data));
auto variance_memory = memory(batch_norm_bwd_pd.variance_primitive_desc(),
to_void_cast(batch_variance_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic; const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data; std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size); scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(), copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
shift->data<T>() + ic, &scaleshift_data); &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{ // create mkldnn memory for input tensors (scale/shift)
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()}; auto scaleshift_memory = memory(batch_norm_bwd_pd.weights_primitive_desc(),
scaleshift_data.data());
// create mkldnn memory for output diff weights (combined scale/shift)
std::vector<T> diff_scaleshift_data; std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size); diff_scaleshift_data.reserve(scaleshift_size);
copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
diff_shift->data<T>(), diff_shift->data<T>() + ic,
&diff_scaleshift_data);
auto diff_scaleshift_memory = auto diff_scaleshift_memory =
mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(), memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data()}; diff_scaleshift_data.data());
auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine}, // here assume diff_src is in the same format of src
static_cast<void *>(diff_x->data<T>())}; auto diff_src_memory = memory(src_memory.get_primitive_desc(), diff_x_data);
run_batch_norm_op<bn_bwd_types::op_type>( // finally create batch_norm backward primitive
batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory, auto batch_norm_bwd_prim =
diff_src, diff_scaleshift_memory); batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory,
variance_memory, diff_dst_memory, scaleshift_memory,
diff_src_memory, diff_scaleshift_memory);
// execute optional reorder and batch_norm backward primitive
std::vector<primitive> pipeline;
if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
pipeline.push_back(batch_norm_bwd_prim);
stream(stream::kind::eager).submit(pipeline).wait();
// copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data); auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, ic), diff_scale->data<T>()); std::copy(it, std::next(it, ic), diff_scale_data);
std::copy(std::next(it, ic), std::end(diff_scaleshift_data), std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift->data<T>()); diff_shift_data);
// set layout/format of output tensors
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory.get_primitive_desc()
.desc()
.data.format);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(batch_norm, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNOpKernel<float>); ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::BatchNormMKLDNNGradOpKernel<float>); ops::BatchNormMKLDNNGradOpKernel<float>);
...@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx.Input<Tensor>("Variance")->type()), ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type"); "Variance input should be of float type");
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_); library);
} }
}; };
...@@ -151,13 +151,15 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -151,13 +151,15 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance", AddInput("Variance",
"The global variance (for training) " "The global variance (for training) "
"or estimated Variance (for testing)"); "or estimated Variance (for testing)");
AddOutput("Y", "result after normalization"); AddOutput("Y", "result after normalization").Reuse("X");
AddOutput("MeanOut", AddOutput("MeanOut",
"Share memory with Mean. " "Share memory with Mean. "
"Store the global mean when training"); "Store the global mean when training")
.Reuse("Mean");
AddOutput("VarianceOut", AddOutput("VarianceOut",
"Share memory with Variance. " "Share memory with Variance. "
"Store the global Variance when training"); "Store the global Variance when training")
.Reuse("Variance");
AddOutput("SavedMean", AddOutput("SavedMean",
"Mean of the current mini batch, " "Mean of the current mini batch, "
"will apply to output when training") "will apply to output when training")
...@@ -368,19 +370,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -368,19 +370,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_); layout, library);
} }
}; };
......
...@@ -125,7 +125,8 @@ void Conv2DOpMaker::Make() { ...@@ -125,7 +125,8 @@ void Conv2DOpMaker::Make() {
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the " "(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of " "strides(h_stride, w_stride) of "
...@@ -220,7 +221,8 @@ void Conv3DOpMaker::Make() { ...@@ -220,7 +221,8 @@ void Conv3DOpMaker::Make() {
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator." "(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the " "(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of " "strides(d_stride, h_stride, w_stride) of "
......
...@@ -124,7 +124,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -124,7 +124,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor<float/double> with shape [N x D]."); "Tensor<float/double> with shape [N x D].");
AddOutput("Y", AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape " "(Tensor, default Tensor<float>), a 2-D tensor with shape "
"[N x 1]. The cross entropy loss."); "[N x 1]. The cross entropy loss.")
.Reuse("X");
AddAttr<bool>("soft_label", AddAttr<bool>("soft_label",
"(bool, default false), a flag indicating whether to " "(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.") "interpretate the given labels as soft labels.")
......
if(WITH_DISTRIBUTE) if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory) selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc SERIAL) cares zlib protobuf sendrecvop_grpc SERIAL)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL) proto_desc lookup_table_op SERIAL)
return()
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
brpc protobuf leveldb gflags glog
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server");
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
BRPCClient::~BRPCClient() { Wait(); }
void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
}
bool BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
framework::AsyncIO(
[var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
google::protobuf::Closure* done =
brpc::NewCallback(&HandleSendResponse, cntl, response);
sendrecv::VariableMessage request;
ch_ctx->stub->SendVariable(cntl, &request, response, done);
});
req_count_++;
return true;
}
void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
}
bool BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
framework::AsyncIO(
[var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
req_count_++;
return true;
}
bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {});
req_count_++;
return true;
}
void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
}
void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
}
void BRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
{
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
}
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "pooled";
options.connect_timeout_ms = 100;
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
options.max_retry = FLAGS_max_retry;
for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
std::shared_ptr<ChannelContext> c(new ChannelContext());
if (c->channel.Init(ep.c_str(), &options) != 0) {
LOG(ERROR) << "Fail to initialize channel";
return nullptr;
}
c->stub.reset(new sendrecv::SendRecvService_Stub(
static_cast<google::protobuf::RpcChannel*>(&c->channel)));
q->Push(c);
}
{
std::lock_guard<std::mutex> guard(chan_mutex_);
channels_[ep] = q;
}
return q;
}
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <time.h>
#include <chrono> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
namespace detail {
struct ChannelContext {
brpc::Channel channel;
std::shared_ptr<sendrecv::SendRecvService_Stub> stub;
};
typedef std::shared_ptr<ChannelContext> ChannelContextPtr;
typedef std::shared_ptr<framework::BlockingQueue<ChannelContextPtr>>
ChannelQueuePtr;
class BRPCClient : public RPCClient {
public:
BRPCClient() {}
virtual ~BRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait() override;
private:
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
private:
std::unordered_map<std::string, ChannelQueuePtr> channels_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(BRPCClient);
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
namespace sendrecv {
typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
: request_send_h_(nullptr),
request_get_h_(nullptr),
request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
}
}
virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_send_h_ != nullptr,
"RequestSend handler should be registed first!");
brpc::ClosureGuard done_guard(done);
paddle::framework::Scope* local_scope = request_send_h_->scope();
paddle::framework::Variable* outvar = nullptr;
paddle::framework::Variable* invar = nullptr;
std::string varname = request->varname();
if (!request_send_h_->sync_mode()) {
local_scope = &request_send_h_->scope()->NewScope();
invar = local_scope->Var(varname);
} else {
invar = local_scope->FindVar(varname);
}
request_send_h_->Handle(varname, local_scope, invar, &outvar);
if (!request_send_h_->sync_mode()) {
request_send_h_->scope()->DeleteScope(local_scope);
}
}
void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_get_h_ != nullptr,
"RequestGet handler should be registed first!");
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
"kRequestPrefetch handler should be registed first!");
}
private:
paddle::operators::detail::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_;
};
} // namespace sendrecv
namespace paddle {
namespace operators {
namespace detail {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
// use brpc::SERVER_OWNS_SERVICE.
if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
LOG(FATAL) << "Fail to add service";
return;
}
brpc::ServerOptions options;
options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) {
LOG(FATAL) << "Fail to start EchoServer" << bind_address_;
return;
}
butil::EndPoint ep = server_.listen_address();
selected_port_ = ep.port;
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
server_.Join();
}
void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); }
void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
}
}; // namespace detail
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace detail {
class AsyncBRPCServer final : public RPCServer {
public:
explicit AsyncBRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
virtual ~AsyncBRPCServer() {}
void StartServer() override;
void WaitServerReady() override;
private:
void ShutDownImpl() override;
brpc::Server server_;
static constexpr int idle_timeout_s_ = -1;
static constexpr int max_concurrency_ = 0;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -33,6 +34,12 @@ void GRPCClient::InitEventLoop() { ...@@ -33,6 +34,12 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
} }
void GRPCClient::SendComplete() {
for (auto& it : channels_) {
this->AsyncSendComplete(it.first);
}
}
GRPCClient::~GRPCClient() { GRPCClient::~GRPCClient() {
Wait(); Wait();
cq_.Shutdown(); cq_.Shutdown();
...@@ -209,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, ...@@ -209,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++; req_count_++;
} }
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() { void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; }); sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
...@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient { ...@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {
void Wait() override; void Wait() override;
void SendComplete() override;
protected: protected:
void InitImpl() override; void InitImpl() override;
...@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient { ...@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient {
void Proceed(); void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private: private:
......
...@@ -41,11 +41,22 @@ class RequestBase { ...@@ -41,11 +41,22 @@ class RequestBase {
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() = 0; virtual void Process() = 0;
CallStatus Status() { return status_; } CallStatus Status() const {
void SetStatus(CallStatus status) { status_ = status; } std::lock_guard<std::mutex> l(status_mu_);
return status_;
}
template <typename T>
void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
std::lock_guard<std::mutex> l(status_mu_);
status_ = FINISH;
responder->Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
virtual std::string GetReqName() = 0; virtual std::string GetReqName() = 0;
protected: protected:
mutable std::mutex status_mu_;
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
...@@ -80,9 +91,7 @@ class RequestSend final : public RequestBase { ...@@ -80,9 +91,7 @@ class RequestSend final : public RequestBase {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar); request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH; Finish(reply_, &responder_);
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
...@@ -122,9 +131,7 @@ class RequestGet final : public RequestBase { ...@@ -122,9 +131,7 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
} }
status_ = FINISH; Finish(reply_, &responder_);
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
...@@ -155,20 +162,20 @@ class RequestPrefetch final : public RequestBase { ...@@ -155,20 +162,20 @@ class RequestPrefetch final : public RequestBase {
void Process() override { void Process() override {
// prefetch process... // prefetch process...
std::string varname = request_->OutVarname(); std::string in_var_name = request_->Varname();
VLOG(3) << "RequestPrefetch " << varname; std::string out_var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname); auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr; framework::Variable* outvar = scope->FindVar(out_var_name);
request_handler_->Handle(varname, scope, invar, &outvar); request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
responder_.Finish(reply_, ::grpc::Status::OK, Finish(reply_, &responder_);
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
} }
protected: protected:
...@@ -282,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -282,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
} else if (rpc_name == kRequestPrefetch) { } else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not surpported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }
reqs[req_id] = b; reqs[req_id] = b;
......
...@@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer { ...@@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer {
void StartServer() override; void StartServer() override;
private: private:
// HandleRequest needs to be thread-safe.
void HandleRequest( void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne); std::function<void(const std::string&, int)> TryToRegisterNewOne);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient
#else
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient
#endif
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,6 +37,11 @@ constexpr char kRequestSend[] = "RequestSend"; ...@@ -38,6 +37,11 @@ constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestPrefetch[] = "RequestPrefetch";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
class RPCServer; class RPCServer;
class RequestHandler { class RequestHandler {
...@@ -57,9 +61,12 @@ class RequestHandler { ...@@ -57,9 +61,12 @@ class RequestHandler {
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; } void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; } void SetExecutor(framework::Executor* executor) { executor_ = executor; }
// Used for dist lookup table prefetch
void SetPrefetchPreparedCtx( void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) { std::unordered_map<
prefetch_ctx_.reset(prepared.release()); std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
prefetch_var_name_to_prepared_ctx_ = g;
} }
// Used for async. // Used for async.
...@@ -75,9 +82,6 @@ class RequestHandler { ...@@ -75,9 +82,6 @@ class RequestHandler {
bool sync_mode() { return sync_mode_; } bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; } framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; } const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; } framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; } framework::Executor* executor() { return executor_; }
...@@ -96,8 +100,8 @@ class RequestHandler { ...@@ -96,8 +100,8 @@ class RequestHandler {
// *request_handler_->dev_ctx(), &reply_); // *request_handler_->dev_ctx(), &reply_);
// } // }
virtual bool Handle(const std::string& varname, framework::Scope* scope, virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable* var, framework::Variable** outvar,
framework::Variable** outvar) = 0; const std::string& out_var_name = "") = 0;
protected: protected:
const bool sync_mode_; const bool sync_mode_;
...@@ -106,12 +110,17 @@ class RequestHandler { ...@@ -106,12 +110,17 @@ class RequestHandler {
framework::Executor* executor_; framework::Executor* executor_;
framework::Scope* scope_; framework::Scope* scope_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// used for distribute lookup table prefetch
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// Used for async. // Used for async.
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>* std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_; grad_to_prepared_ctx_;
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
......
...@@ -16,15 +16,12 @@ ...@@ -16,15 +16,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,7 +30,8 @@ namespace detail { ...@@ -33,7 +30,8 @@ namespace detail {
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Async // Async
...@@ -52,6 +50,9 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -52,6 +50,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message"; VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->DecreaseClientNum();
} else { } else {
VLOG(3) << "sync: received var_name: " << varname; VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) { if (sync_mode_) {
...@@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() { ...@@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname; VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) { if (varname != FETCH_BARRIER_MESSAGE) {
...@@ -105,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -105,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname,
bool RequestPrefetchHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
framework::Variable** outvar) { framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestPrefetchHandler " << varname; VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname); auto var_desc = program_->Block(0).FindVar(out_var_name);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType()); InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope); executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
return true; return true;
} }
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -40,7 +39,8 @@ class RequestSendHandler final : public RequestHandler { ...@@ -40,7 +39,8 @@ class RequestSendHandler final : public RequestHandler {
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {} virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder(); void ResetSparseVarRecorder();
private: private:
...@@ -53,7 +53,8 @@ class RequestGetHandler final : public RequestHandler { ...@@ -53,7 +53,8 @@ class RequestGetHandler final : public RequestHandler {
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {} virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
}; };
class RequestPrefetchHandler final : public RequestHandler { class RequestPrefetchHandler final : public RequestHandler {
...@@ -61,7 +62,8 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -61,7 +62,8 @@ class RequestPrefetchHandler final : public RequestHandler {
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {} virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override; framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
}; };
} // namespace detail } // namespace detail
......
...@@ -26,6 +26,8 @@ namespace detail { ...@@ -26,6 +26,8 @@ namespace detail {
class RPCClient { class RPCClient {
public: public:
RPCClient() {}
virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep, virtual bool AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -51,6 +53,11 @@ class RPCClient { ...@@ -51,6 +53,11 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = rpc_time_out) = 0;
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
virtual void SendComplete() = 0;
virtual void Wait() = 0; virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000; static constexpr int64_t rpc_time_out = 120 * 1000;
......
...@@ -43,7 +43,7 @@ void RPCServer::SavePort() const { ...@@ -43,7 +43,7 @@ void RPCServer::SavePort() const {
void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] { barrier_cond_.wait(lock, [this, &rpc_name] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
}); });
...@@ -53,17 +53,21 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { ...@@ -53,17 +53,21 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0; int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name]; b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
if (b >= client_num_) { if (b >= client_num_) {
lock.unlock();
barrier_cond_.notify_all(); barrier_cond_.notify_all();
lock.lock();
}
}
void RPCServer::DecreaseClientNum() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
} }
barrier_cond_.notify_all();
} }
void RPCServer::ResetBarrierCounter() { void RPCServer::ResetBarrierCounter() {
......
...@@ -60,7 +60,7 @@ class RPCServer { ...@@ -60,7 +60,7 @@ class RPCServer {
void SetCond(const std::string& rpc_name); void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name); void IncreaseBatchBarrier(const std::string rpc_name);
void DecreaseClientNum();
void ResetBarrierCounter(); void ResetBarrierCounter();
protected: protected:
...@@ -79,8 +79,7 @@ class RPCServer { ...@@ -79,8 +79,7 @@ class RPCServer {
std::string bind_address_; std::string bind_address_;
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
int selected_port_; int selected_port_;
int client_num_;
const int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_; std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_; std::unordered_map<std::string, int> rpc_thread_num_;
......
...@@ -17,15 +17,14 @@ limitations under the License. */ ...@@ -17,15 +17,14 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
...@@ -33,7 +32,7 @@ namespace detail = paddle::operators::detail; ...@@ -33,7 +32,7 @@ namespace detail = paddle::operators::detail;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service; std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
...@@ -99,11 +98,17 @@ void StartServer() { ...@@ -99,11 +98,17 @@ void StartServer() {
framework::Executor exe(place); framework::Executor exe(place);
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program); auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID()); std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10); InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program); g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared)); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx); g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
...@@ -112,20 +117,19 @@ void StartServer() { ...@@ -112,20 +117,19 @@ void StartServer() {
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true)); g_req_handler.reset(new detail::RequestPrefetchHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
syntax = "proto3"; syntax = "proto3";
package sendrecv; package sendrecv;
// option cc_generic_services = true;
service SendRecvService { service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors. // For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor // Send and recv only one tensor
......
...@@ -32,16 +32,6 @@ namespace paddle { ...@@ -32,16 +32,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
...@@ -59,47 +59,48 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -59,47 +59,48 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() final { void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op."); AddOutput("Out", "The output of elementwise op.").Reuse("X");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default -1). The start dimension index " "(int, default -1). The start dimension index "
"for broadcasting Y onto X.") "for broadcasting Y onto X.")
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
AddComment(string::Sprintf(R"DOC( AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator. Limited Elementwise %s Operator
The equation is: The equation is:
$$%s$$ $$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be - $X$: a tensor of any dimension.
smaller than or equal to the dimensions of $X$. - $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator: There are two cases for this operator:
1. The shape of $Y$ is same with $X$;
2. The shape of $Y$ is a congiguous subsequencet of $X$. The trailing dimensions
of size 1 for $Y$ will be ignored for the consideration of subsequence.
1. The shape of $Y$ is the same with $X$.
2. The shape of $Y$ is a continuous subsequence of $X$.
For case 2: For case 2:
$Y$ will be broadcasted to match the shape of $X$ and axis should be 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
set to index of the start dimension to broadcast $Y$ onto $X$. for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2).
If axis is -1, it is treated as axis=rank(X)-rank(Y). For example:
For example
.. code-block:: python .. code-block:: python
shape(X) = (2, 3, 4, 5), shape(Y) = (,) shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,) shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details) The inputs $X$ and $Y$ can carry the different LoD information.
information. However, the output only shares the LoD information with input $X$. But the output only shares the LoD information with the input $X$.
)DOC", )DOC",
GetName(), GetEquation())); GetName(), GetEquation()));
......
...@@ -19,9 +19,7 @@ limitations under the License. */ ...@@ -19,9 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -45,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait(); rpc_client->Wait();
......
...@@ -21,8 +21,7 @@ limitations under the License. */ ...@@ -21,8 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
...@@ -61,8 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -61,8 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient* client = detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
...@@ -78,9 +77,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -78,9 +77,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::RequestSendHandler rpc_h(true); detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1); std::unique_ptr<detail::RPCServer> rpc_service(
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h); new RPCSERVER_T(endpoint, 1));
rpc_h.SetRPCServer(&rpc_service);
rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
...@@ -90,12 +91,13 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -90,12 +91,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service)); std::bind(&detail::RPCServer::StartServer, rpc_service.get()));
rpc_service.SetCond(detail::kRequestSend);
rpc_service->SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service.WaitBarrier(detail::kRequestSend); rpc_service->WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
server_thread.join(); server_thread.join();
} }
......
...@@ -19,7 +19,8 @@ limitations under the License. */ ...@@ -19,7 +19,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -89,19 +90,28 @@ void ListenAndServOp::SavePort() const { ...@@ -89,19 +90,28 @@ void ListenAndServOp::SavePort() const {
rpc_service_->SavePort(); rpc_service_->SavePort();
} }
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, static int64_t GetTimestamp() {
framework::ProgramDesc *program, struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> block_list; std::vector<int> optimize_block_id_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) { for (int blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid); if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
blkid) == prefetch_block_id_list.end()) {
optimize_block_id_list.push_back(blkid);
} }
auto optimize_prepared = executor->Prepare(*program, block_list); }
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
optimize_prepared.begin(), optimize_prepared.begin(),
...@@ -127,9 +137,11 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -127,9 +137,11 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = program->Block(1).Parent();
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(1);
double ts = detail::GetTimestamp(); double ts = GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { // skip the first optimize block because it is already in the
// parallel_blkids.
int blkid = optimize_block_id_list[i];
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope); program, recv_scope);
...@@ -138,10 +150,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -138,10 +150,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} }
parallel_blkids.push_back(blkid); parallel_blkids.push_back(blkid);
} }
}
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet);
...@@ -203,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -203,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true) } // while(true)
} }
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope, static void FillRequestCtx(
platform::DeviceContext *dev_ctx, detail::RequestHandler *h, framework::Scope *scope,
framework::Executor *executor, platform::DeviceContext *dev_ctx, framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
detail::RPCServer *rpc_server) { detail::RPCServer *rpc_server) {
h->SetScope(scope); h->SetScope(scope);
h->SetDevCtx(dev_ctx); h->SetDevCtx(dev_ctx);
h->SetExecutor(executor); h->SetExecutor(executor);
h->SetProgram(program); h->SetProgram(program);
h->SetPrefetchPreparedCtx( h->SetPrefetchPreparedCtx(prefetch_ctx);
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
h->SetRPCServer(rpc_server); h->SetRPCServer(rpc_server);
} }
...@@ -235,8 +247,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -235,8 +247,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint; << ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
...@@ -248,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -248,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare for prefetch // prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID(); std::vector<int> prefetch_block_id_list;
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
auto prefetch_var_name_to_block_id_str =
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
for (const auto &prefetch_var_name_and_id :
prefetch_var_name_to_block_id_str) {
std::vector<std::string> pieces;
split(prefetch_var_name_and_id, ':', &pieces);
VLOG(3) << "after split, prefetch_var = " << pieces[0]
<< ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
int block_id = std::stoi(pieces[1]);
prefetch_block_id_list.push_back(block_id);
block_id_to_prefetch_var_name[block_id] = pieces[0];
}
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(), &dev_ctx, &executor, program,
rpc_service_.get()); &prefetch_var_name_to_prepared_ctx, rpc_service_.get());
f(request_send_handler_.get()); f(request_send_handler_.get());
f(request_get_handler_.get()); f(request_get_handler_.get());
...@@ -276,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -276,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
} else { } else {
RunAsyncLoop(&executor, program); RunAsyncLoop(&executor, program);
} }
...@@ -302,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -302,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock, AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch block to run on server side."); "prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.") AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1); .SetDefault(1);
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic> #include <atomic>
#include <set> #include <set>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -30,7 +31,7 @@ namespace paddle { ...@@ -30,7 +31,7 @@ namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service); void RunServer(std::shared_ptr<detail::RPCServer> service);
...@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const; const std::vector<int>& prefetch_block_id_list) const;
void RunAsyncLoop(framework::Executor* executor, void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const; framework::ProgramDesc* program) const;
......
...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op"); AddOutput("Out", "The output of mean op").Reuse("X");
AddComment(R"DOC( AddComment(R"DOC(
Mean Operator. Mean Operator.
......
...@@ -16,40 +16,34 @@ limitations under the License. */ ...@@ -16,40 +16,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker { class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput( AddInput("X", "(Tensor) A tensor of rank >= axis.");
"X", AddAttr<int>("axis",
"(Tensor) The input tensor of norm operator. " "The axis on which to apply normalization. If axis < 0, "
"The format of input tensor is NCHW. Where N is batch size, C is the " "the dimension to normalization is rank(X) + axis. -1 is "
"number of channels, H and W is the height and width of feature."); "the last dimension.");
AddInput("Scale", AddAttr<float>("epsilon",
"(Tensor) The input tensor of norm operator. " "(float, default 1e-10) The epsilon value is used "
"The format of input tensor is C * 1."); "to avoid division by zero.")
AddAttr<AttrType>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f); .SetDefault(1.0e-10f);
AddOutput("Out", AddOutput("Norm",
"(Tensor) The output tensor of norm operator." "(Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will "
"N * M." "be used in backward kernel.")
"M = C * H * W"); .AsIntermediate();
AddOutput("Out", "(Tensor) A tensor of the same shape as X.");
AddComment(R"DOC( AddComment(R"DOC(
"Input shape: $(N, C, H, W)$
Scale shape: $(C, 1)$ Given a tensor, apply 2-normalization along the provided axis.
Output shape: $(N, C, H, W)$
Where $$
forward y = \frac{x}{ \sqrt{\sum {x^2} + epsion }}
$$ $$
[\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdot \cdot \cdot \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}]
$$ where, $\sum {x^2}$ is calculated along the `axis` dimension.
backward
$$ )DOC");
\frac{\frac{\mathrm{d}L }{\mathrm{d}y_{1}} - \frac {x_{1}\sum {\frac{\mathrm{d} L}{\mathrm{d} y_{j}}}x_{j}}{\sum x_{j}^{2}} }{\sqrt{\sum{x_{j}^{2}}}}
$$
)DOC");
} }
}; };
...@@ -58,15 +52,15 @@ class NormOp : public framework::OperatorWithKernel { ...@@ -58,15 +52,15 @@ class NormOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of NormOp" "Input(X) of NormOp should not be null.");
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of NormOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of NormOp should not be null."); "Output(Out) of NormOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X"); auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_x_dims); ctx->SetOutputDim("Out", xdim);
int axis = ctx->Attrs().Get<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
xdim[axis] = 1;
ctx->SetOutputDim("Norm", xdim);
} }
}; };
...@@ -84,12 +78,12 @@ class NormOpGrad : public framework::OperatorWithKernel { ...@@ -84,12 +78,12 @@ class NormOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker<float>, using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad); REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>, ops::NormKernel<CPU, double>);
ops::NormKernel<paddle::platform::CPUDeviceContext, double, float>); REGISTER_OP_CPU_KERNEL(norm_grad, ops::NormGradKernel<CPU, float>,
REGISTER_OP_CPU_KERNEL( ops::NormGradKernel<CPU, double>);
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double, float>);
...@@ -16,9 +16,9 @@ limitations under the License. */ ...@@ -16,9 +16,9 @@ limitations under the License. */
#include "paddle/fluid/operators/norm_op.h" #include "paddle/fluid/operators/norm_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( using CUDA = paddle::platform::CUDADeviceContext;
norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
ops::NormKernel<paddle::platform::CUDADeviceContext, double, float>); REGISTER_OP_CUDA_KERNEL(norm, ops::NormKernel<CUDA, float>,
REGISTER_OP_CUDA_KERNEL( ops::NormKernel<CUDA, double>);
norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>, REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradKernel<CUDA, float>,
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double, float>); ops::NormGradKernel<CUDA, double>);
...@@ -19,156 +19,110 @@ limitations under the License. */ ...@@ -19,156 +19,110 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T, typename AttrType = T> inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
template <typename DeviceContext, typename T>
class NormKernel : public framework::OpKernel<T> { class NormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X"); auto* in_x = ctx.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale"); auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out = context.Output<framework::Tensor>("Out"); auto* out_norm = ctx.Output<framework::Tensor>("Norm");
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon")); out_y->mutable_data<T>(ctx.GetPlace());
out->mutable_data<T>(context.GetPlace()); out_norm->mutable_data<T>(ctx.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1]; auto xdim = in_x->dims();
int height = in_x->dims()[2]; auto ndim = out_norm->dims();
int width = in_x->dims()[3]; T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int fea_len = height * width; int axis = ctx.Attr<int>("axis");
auto* place = if (axis < 0) axis = xdim.size() + axis;
context.template device_context<DeviceContext>().eigen_device(); int pre, n, post;
auto x = GetDims(xdim, axis, &pre, &n, &post);
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in_x, framework::make_ddim({batch_size, fea_len * channels})); auto* place = ctx.template device_context<DeviceContext>().eigen_device();
// get square
framework::Tensor x_square; Eigen::DSizes<int, 3> shape(pre, n, post);
x_square.mutable_data<T>(in_x->dims(), context.GetPlace()); Eigen::DSizes<int, 2> norm_shape(pre, post);
auto x_square_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( auto x_e = framework::EigenVector<T>::Flatten(*in_x);
x_square, framework::make_ddim({batch_size, fea_len * channels})); auto y_e = framework::EigenVector<T>::Flatten(*out_y);
x_square_eigen.device(*place) = x.square(); auto norm_e = framework::EigenVector<T>::Flatten(*out_norm);
auto scale_eigen = auto x = x_e.reshape(shape);
framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten( auto y = y_e.reshape(shape);
*scale); auto norm = norm_e.reshape(norm_shape);
for (int n = 0; n < batch_size; ++n) {
framework::Tensor in_x_batch = in_x->Slice(n, n + 1); Eigen::DSizes<int, 1> rdim(1);
auto in_x_batch_eigen = // y = x / sqrt((sum(x * x) + epsilon))
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( // norm = sqrt(sum(x * x) + epsilon)
in_x_batch, framework::make_ddim({channels, fea_len})); auto sum = x.pow(2).sum(rdim) + eps;
framework::Tensor x_square_batch = x_square.Slice(n, n + 1); norm.device(*place) = sum.sqrt();
auto x_square_batch_eigen = // y = x / norm
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( Eigen::DSizes<int, 3> rshape(pre, 1, post);
x_square_batch, framework::make_ddim({channels, fea_len})); Eigen::DSizes<int, 3> bcast(1, n, 1);
framework::Tensor out_batch = out->Slice(n, n + 1); y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
auto out_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
out_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp = framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor);
// get colsum and sqrt , inverse
auto dim = Eigen::array<int, 1>({{0}});
tmp.device(*place) = x_square_batch_eigen.sum(dim);
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
out_batch_eigen.device(*place) =
in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col));
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
out_batch_eigen.device(*place) =
out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
} }
}; };
template <typename DeviceContext, typename T, typename AttrType = T> template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> { class NormGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X"); auto* in_x = ctx.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale"); auto* in_norm = ctx.Input<framework::Tensor>("Norm");
const framework::Tensor* out_grad = auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
context.Input<framework::Tensor>(framework::GradVarName("Out")); auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon")); out_dx->mutable_data<T>(ctx.GetPlace());
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X")); auto xdim = in_x->dims();
in_x_grad->mutable_data<T>(context.GetPlace()); int axis = ctx.Attr<int>("axis");
int batch_size = in_x->dims()[0]; if (axis < 0) axis = xdim.size() + axis;
int channels = in_x->dims()[1]; int pre, n, post;
int height = in_x->dims()[2]; GetDims(xdim, axis, &pre, &n, &post);
int width = in_x->dims()[3];
int fea_len = height * width; auto* place = ctx.template device_context<DeviceContext>().eigen_device();
auto* place =
context.template device_context<DeviceContext>().eigen_device(); auto x_e = framework::EigenVector<T>::Flatten(*in_x);
auto dy_e = framework::EigenVector<T>::Flatten(*in_dy);
auto scale_eigen = auto norm_e = framework::EigenVector<T>::Flatten(*in_norm);
framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten( auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
*scale);
auto x = Eigen::DSizes<int, 3> shape(pre, n, post);
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( Eigen::DSizes<int, 2> norm_shape(pre, post);
*in_x, framework::make_ddim({batch_size, fea_len * channels})); auto x = x_e.reshape(shape);
// get square auto dy = dy_e.reshape(shape);
framework::Tensor x_square; auto norm = norm_e.reshape(norm_shape);
x_square.mutable_data<T>(in_x->dims(), context.GetPlace()); auto dx = dx_e.reshape(shape);
auto x_square_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::Tensor rsum;
x_square, framework::make_ddim({batch_size, fea_len * channels})); rsum.mutable_data<T>({pre, post}, ctx.GetPlace());
x_square_eigen.device(*place) = x.square(); auto sum = framework::EigenTensor<T, 2>::From(rsum);
for (int n = 0; n < batch_size; ++n) { Eigen::DSizes<int, 1> rdim(1);
framework::Tensor in_x_batch = in_x->Slice(n, n + 1); Eigen::DSizes<int, 3> bcast(1, n, 1);
auto in_x_batch_eigen = Eigen::DSizes<int, 3> rshape(pre, 1, post);
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
in_x_batch, framework::make_ddim({channels, fea_len})); // dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1); // = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
auto in_g_batch_eigen = // = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( // 1. sum = sum(x*dy)
in_g_batch, framework::make_ddim({channels, fea_len})); sum.device(*place) = (x * dy).sum(rdim);
framework::Tensor x_square_batch = x_square.Slice(n, n + 1); // 2. dx = x * sum
auto x_square_batch_eigen = dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x;
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( // 3. dx / (sum(x*x) + e)
x_square_batch, framework::make_ddim({channels, fea_len})); // where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
framework::Tensor outg_batch = out_grad->Slice(n, n + 1); dx.device(*place) = dx / norm.pow(2).broadcast(bcast);
auto outg_batch_eigen = // 4. [dy - dx] / sqrt(sum(x*x))
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From( dx.device(*place) = (dy - dx) / norm.broadcast(bcast);
outg_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp_eigen =
framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor);
auto dim = Eigen::array<int, 1>({{0}});
tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim);
framework::Tensor norm_tmp_tensor;
norm_tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto norm_tmp_eigen =
framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(norm_tmp_tensor);
norm_tmp_eigen.device(*place) =
(x_square_batch_eigen.sum(dim) + epsilon).sqrt();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
in_g_batch_eigen.device(*place) =
in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen /
(norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen;
// outg_batch_eigen + (in_g_batch_eigen * -1);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col);
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
in_g_batch_eigen.device(*place) =
in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -18,9 +18,14 @@ limitations under the License. */ ...@@ -18,9 +18,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using mkldnn::memory; // Note: paddle has also "memory" namespace using framework::DataLayout;
using mkldnn::pooling_forward; using mkldnn::memory;
using mkldnn::pooling_backward; using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
// Generate keys for storing/retriving primitives for this operator // Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial // TODO(jczaja): Make hashing function more optimial
...@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("X"); const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
// Get an unique name from "argument" name of "Out" variable PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
// This name will be used as key when saving info into device context input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
...@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
auto input_format = input->format();
memory::format output_format{memory::format::format_undef};
const std::string key = gethash(src_tz, pooling_type, ksize, strides, const std::string key = gethash(src_tz, pooling_type, ksize, strides,
paddings, ctx.op().Output("Out")); paddings, ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p"; const std::string key_pool_p = key + "@pool_p";
...@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto pool_p = auto pool_p =
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p)); std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
if (pool_p == nullptr) { if (pool_p == nullptr) {
// TODO(pzelazko-intel): support more formats auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), input_format);
auto src_md = /* create memory descriptor for pooling without specified format
platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(), * ('any') which lets a primitive (pooling in this case) choose
mkldnn::memory::format::nchw); * the memory format preferred for best performance
auto dst_md = */
platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(), auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); mkldnn::memory::format::any);
std::shared_ptr<pooling_forward::primitive_desc> pool_pd = std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize, CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
pooling_type, mkldnn_engine); pooling_type, mkldnn_engine);
...@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path // save pool_workspace_memory to be referred in backward path
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
auto pool_src_memory_p = std::make_shared<memory>( auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
memory::primitive_desc{src_md, mkldnn_engine}, to_void_cast<T>(input_data));
static_cast<void*>(const_cast<T*>(input_data))); auto dst_memory =
dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p); std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
auto pool_dst_memory_p = std::make_shared<memory>( dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
memory::primitive_desc{dst_md, mkldnn_engine}, dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
static_cast<void*>(output_data));
dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
pool_p = std::make_shared<pooling_forward>( pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
*pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()), *(dst_memory.get()),
*workspace_memory); *workspace_memory);
dev_ctx.SetBlob(key_pool_p, pool_p); dev_ctx.SetBlob(key_pool_p, pool_p);
output_format =
(memory::format)dst_memory->get_primitive_desc().desc().data.format;
} else { } else {
// Primitives already exist // Primitives already exist
auto pool_src_memory_p = auto pool_src_memory_p =
...@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p)); std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
PADDLE_ENFORCE(pool_dst_memory_p != nullptr, PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
"Fail to find pooling dst mem_p in device context"); "Fail to find pooling dst mem_p in device context");
pool_src_memory_p->set_data_handle( pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
reinterpret_cast<void*>(const_cast<T*>(input_data)));
pool_dst_memory_p->set_data_handle(output_data); pool_dst_memory_p->set_data_handle(output_data);
output_format = (memory::format)pool_dst_memory_p->get_primitive_desc()
.desc()
.data.format;
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*(pool_p.get())}; std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
} }
private: private:
...@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN &&
in_x->format() != memory::format::format_undef,
"Wrong layout/format set for Input X tensor");
PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN &&
out_grad->format() != memory::format::format_undef,
"Wrong layout/format set for Input output_grad tensor");
std::string pooling_type = ctx.Attr<std::string>("pooling_type"); std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
...@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace()); T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
memory::format in_x_grad_format{memory::format::format_undef};
std::vector<int> diff_src_tz = std::vector<int> diff_src_tz =
paddle::framework::vectorize2int(in_x_grad->dims()); paddle::framework::vectorize2int(in_x_grad->dims());
...@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_pool_bwd_p = key + "@pool_bwd_p"; const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p"; const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
const std::string key_pool_pd = key + "@pool_pd"; const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_workspace_memory = const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory"; key + "@pool_workspace_memory";
auto user_diff_dst_memory =
memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()},
mkldnn_engine},
to_void_cast<T>(out_grad_data));
std::shared_ptr<memory> diff_src_memory;
std::shared_ptr<memory> diff_dst_memory;
auto dst_memory =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
PADDLE_ENFORCE(dst_memory != nullptr,
"Fail to find dst_memory in device context");
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
auto pool_bwd_p = std::static_pointer_cast<pooling_backward>( auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
dev_ctx.GetBlob(key_pool_bwd_p)); dev_ctx.GetBlob(key_pool_bwd_p));
if (pool_bwd_p == nullptr) { if (pool_bwd_p == nullptr) {
auto diff_src_md = // Retrieve src_memory/dst_memory saved in forward pass
platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(), auto src_memory =
mkldnn::memory::format::nchw); std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
auto diff_dst_md = PADDLE_ENFORCE(src_memory != nullptr,
platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(), "Fail to find src_memory in device context");
mkldnn::memory::format::nchw);
// Retrieve pool_pd/pool_workspace_memory from device context // Retrieve pool_pd/pool_workspace_memory from device context
auto pool_pd = auto pool_pd =
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>( std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
dev_ctx.GetBlob(key_pool_pd)); dev_ctx.GetBlob(key_pool_pd));
PADDLE_ENFORCE(pool_pd != nullptr, PADDLE_ENFORCE(pool_pd != nullptr,
"Fail to find pool_pd in device context"); "Fail to find pool_pd in device context");
auto workspace_memory = std::static_pointer_cast<memory>(
auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_pool_workspace_memory)); dev_ctx.GetBlob(key_pool_workspace_memory));
PADDLE_ENFORCE(workspace_memory != nullptr, PADDLE_ENFORCE(workspace_memory != nullptr,
"Fail to find workspace_memory in device context"); "Fail to find workspace_memory in device context");
auto pool_diff_src_memory_p = std::make_shared<memory>(memory( // create memory descriptors for pooling
{diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data))); auto diff_src_md = src_memory.get()->get_primitive_desc().desc();
dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p); auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc();
auto pool_diff_dst_memory_p = std::make_shared<memory>(
memory({diff_dst_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(out_grad_data))));
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
auto pool_bwd_desc = mkldnn::pooling_backward::desc( auto pool_bwd_desc = mkldnn::pooling_backward::desc(
pooling_type == "max" ? mkldnn::algorithm::pooling_max pooling_type == "max" ? mkldnn::algorithm::pooling_max
...@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc( auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
pool_bwd_desc, mkldnn_engine, *pool_pd); pool_bwd_desc, mkldnn_engine, *pool_pd);
// reorder between user_diff_dst and pool diff_dst if needed
diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory =
std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
}
diff_src_memory = std::make_shared<memory>(
pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data);
dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory);
dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory);
pool_bwd_p = std::make_shared<pooling_backward>( pool_bwd_p = std::make_shared<pooling_backward>(
pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory, pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory,
*(pool_diff_src_memory_p)); *(diff_src_memory));
dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
} else { } else {
// Primitives already exist // Primitives already exist
auto pool_diff_src_memory_p = std::static_pointer_cast<memory>( diff_src_memory = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_pool_diff_src_mem_p)); dev_ctx.GetBlob(key_pool_diff_src_mem_p));
PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr, PADDLE_ENFORCE(diff_src_memory != nullptr,
"Fail to find pooling src mem_p in device context"); "Fail to find pooling src mem_p in device context");
auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>( diff_dst_memory = std::static_pointer_cast<memory>(
dev_ctx.GetBlob(key_pool_diff_dst_mem_p)); dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr, PADDLE_ENFORCE(diff_dst_memory != nullptr,
"Fail to find pooling dst mem_p in device context"); "Fail to find pooling dst mem_p in device context");
pool_diff_src_memory_p->set_data_handle(
reinterpret_cast<void*>(in_x_grad_data)); diff_src_memory->set_data_handle(reinterpret_cast<void*>(in_x_grad_data));
pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data)); diff_dst_memory->set_data_handle(const_cast<T*>(out_grad_data));
// reorder between user_diff_dst and pool diff_dst if needed
if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory =
std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
is_diff_dst_reordered = true;
} }
}
in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())}; std::vector<mkldnn::primitive> pipeline;
if (is_diff_dst_reordered) {
pipeline.push_back(reorder_diff_dst);
}
pipeline.push_back(*(pool_bwd_p.get()));
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(in_x_grad_format);
} // Compute() } // Compute()
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::PoolMKLDNNOpKernel<float>); ops::PoolMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::PoolMKLDNNGradOpKernel<float>); ops::PoolMKLDNNGradOpKernel<float>);
...@@ -151,7 +151,8 @@ void Pool2dOpMaker::Make() { ...@@ -151,7 +151,8 @@ void Pool2dOpMaker::Make() {
"The format of output tensor is also NCHW, " "The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, " "where N is batch size, C is the number of channels, "
"H is the height of the feature, " "H is the height of the feature, "
"and W is the width of the feature."); "and W is the width of the feature.")
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
...@@ -244,7 +245,8 @@ void Pool3dOpMaker::Make() { ...@@ -244,7 +245,8 @@ void Pool3dOpMaker::Make() {
"The format of output tensor is also NCDHW, " "The format of output tensor is also NCDHW, "
"where N is batch size, C is " "where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and " "the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively."); "width of the feature, respectively.")
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string) Pooling type, can be \"max\" for max-pooling " "(string) Pooling type, can be \"max\" for max-pooling "
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
namespace paddle { namespace paddle {
...@@ -42,7 +42,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -42,7 +42,7 @@ class PrefetchOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -20,7 +20,7 @@ namespace reader { ...@@ -20,7 +20,7 @@ namespace reader {
class BatchReader : public framework::DecoratedReader { class BatchReader : public framework::DecoratedReader {
public: public:
BatchReader(ReaderBase* reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) { : DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
......
...@@ -22,7 +22,8 @@ namespace reader { ...@@ -22,7 +22,8 @@ namespace reader {
class CustomReader : public framework::DecoratedReader { class CustomReader : public framework::DecoratedReader {
public: public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block, CustomReader(const std::shared_ptr<ReaderBase>& reader,
const framework::BlockDesc& sub_block,
const std::vector<std::string>& source_var_names, const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names) const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader), : DecoratedReader(reader),
......
...@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2 ...@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
explicit DoubleBufferReader( explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) { : DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize); cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize); gpu_tensor_cache_.resize(kCacheSize);
......
...@@ -21,7 +21,7 @@ namespace reader { ...@@ -21,7 +21,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader { class MultiPassReader : public framework::DecoratedReader {
public: public:
MultiPassReader(ReaderBase* reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
......
...@@ -23,7 +23,8 @@ namespace reader { ...@@ -23,7 +23,8 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader { class ShuffleReader : public framework::DecoratedReader {
public: public:
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0) ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) { : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
VLOG(10) << "Create shuffle reader of " << reader_; VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) { if (seed_ == 0) {
......
...@@ -21,7 +21,8 @@ namespace reader { ...@@ -21,7 +21,8 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader { class ThreadedReader : public framework::DecoratedReader {
public: public:
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {} explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
...@@ -19,8 +19,7 @@ limitations under the License. */ ...@@ -19,8 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +44,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -45,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
...@@ -78,9 +77,15 @@ This operator can get variables from server side. ...@@ -78,9 +77,15 @@ This operator can get variables from server side.
} }
}; };
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker); REGISTER_OPERATOR(recv, ops::RecvOp, paddle::framework::EmptyGradOpMaker,
ops::RecvOpMaker, ops::RecvOpShapeInference);
...@@ -19,8 +19,8 @@ limitations under the License. */ ...@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +45,7 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -45,7 +45,7 @@ class SendBarrierOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -36,12 +35,9 @@ class SendOp : public framework::OperatorBase { ...@@ -36,12 +35,9 @@ class SendOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -50,37 +46,19 @@ class SendOp : public framework::OperatorBase { ...@@ -50,37 +46,19 @@ class SendOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
rpc_client->Wait(); if (sync_send) {
if (sync_mode) {
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
rpc_client->Wait();
}
if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) {
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
}
rpc_client->Wait();
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
rpc_client->Wait(); rpc_client->Wait();
} }
} }
...@@ -89,26 +67,22 @@ class SendOp : public framework::OperatorBase { ...@@ -89,26 +67,22 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
This operator will send tensor to recv_op at the parameter server. This operator will send variables to listen_and_serve op at the parameter server.
)DOC"); )DOC");
// TODO(typhoonzero): remove this attr generate de-duplicated vector from AddAttr<int>("sync_mode",
// epmap when initializing. "(int, default 0)"
AddAttr<std::vector<std::string>>("endpoints", "sync send or async send.")
"(string vector, default 127.0.0.1:6164)" .SetDefault(0);
"Server endpoints to send variables to.")
.SetDefault({});
AddAttr<std::vector<std::string>>("epmap", AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({"127.0.0.1:6164"});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
class SendVarsOp : public framework::OperatorBase {
public:
SendVarsOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_send");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
rpc_client->Wait();
}
}
};
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddComment(R"DOC(
Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("sync_send",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({"127.0.0.1:6164"});
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpShapeInference);
...@@ -74,7 +74,8 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -74,7 +74,8 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut", AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) " "(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param"); "Output parameter, should share the same memory with Param")
.Reuse("Param");
AddComment(R"DOC( AddComment(R"DOC(
SGD operator SGD operator
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/slice_op.h"
#include <algorithm>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class SliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output (Out) of slice op should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE(in_dims.size() < 7,
"The rank of input should be less than 7.");
framework::DDim out_dims(in_dims);
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
PADDLE_ENFORCE_EQ(starts.size(), ends.size());
PADDLE_ENFORCE_EQ(starts.size(), axes.size());
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
start = std::min(start, dim_value);
end = std::min(end, dim_value);
start = std::min(start, end);
out_dims[axes[i]] = end - start;
}
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.GetPlace());
}
};
class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Sliced data tensor.");
AddAttr<std::vector<int>>(
"axes",
"(list<int>) Axes that `starts` and `ends` apply to. It's optional."
"If not present, will be treated as [0, 1, ..., len(`starts`) - 1].");
AddAttr<std::vector<int>>(
"starts",
"(list<int>) Starting indices of corresponding axis in `axes`");
AddAttr<std::vector<int>>(
"ends",
"(list<int>) Starting indices of corresponding axis in `axes`.");
AddComment(R"DOC(
Slice Operator.
Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
end dimension for each axis in the list of axes, it uses this information
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
of that dimension. If the value passed to start or end is larger than
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
Example 1:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
Then:
result = [ [5, 6, 7], ]
Example 2:
Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
starts = [0, 1]
ends = [-1, 1000]
Then:
result = [ [2, 3, 4], ]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/slice_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
break;
case 2:
SliceCompute<2>(ctx);
break;
case 3:
SliceCompute<3>(ctx);
break;
case 4:
SliceCompute<4>(ctx);
break;
case 5:
SliceCompute<5>(ctx);
break;
case 6:
SliceCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
auto in_dims = in->dims();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
offsets[axes[i]] = start;
}
auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out);
out_t.device(place) = in_t.slice(offsets, extents);
}
};
} // namespace operators
} // namespace paddle
...@@ -83,7 +83,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,7 +83,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."); "2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "The normalized values with the same shape as X."); AddOutput("Out", "The normalized values with the same shape as X.")
.Reuse("X");
AddAttr<bool>( AddAttr<bool>(
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
......
...@@ -115,7 +115,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -115,7 +115,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator."); AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
AddComment(R"DOC( AddComment(R"DOC(
Sum operator. Sum operator.
......
...@@ -20,8 +20,7 @@ limitations under the License. */ ...@@ -20,8 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -29,6 +28,10 @@ limitations under the License. */ ...@@ -29,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/send_recv_util.h"
#endif
USE_NO_KERNEL_OP(listen_and_serv); USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework; namespace f = paddle::framework;
...@@ -37,7 +40,7 @@ namespace m = paddle::operators::math; ...@@ -37,7 +40,7 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service; std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer() { void StartServer() {
...@@ -58,7 +61,7 @@ void StartServer() { ...@@ -58,7 +61,7 @@ void StartServer() {
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend); g_rpc_service->SetCond(detail::kRequestSend);
g_rpc_service->WaitBarrier(detail::kRequestSend); g_rpc_service->WaitBarrier(detail::kRequestSend);
...@@ -68,9 +71,9 @@ void StartServer() { ...@@ -68,9 +71,9 @@ void StartServer() {
server_thread.join(); server_thread.join();
} }
TEST(SendNcclId, GrpcServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new detail::RequestSendHandler(true)); g_req_handler.reset(new detail::RequestSendHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
...@@ -87,8 +90,9 @@ TEST(SendNcclId, GrpcServer) { ...@@ -87,8 +90,9 @@ TEST(SendNcclId, GrpcServer) {
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
LOG(INFO) << "connect to server" << ep; LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client->Wait(); client->Wait();
......
...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of Topk op"); AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC( AddComment(R"DOC(
Top K operator Top K operator
......
...@@ -21,12 +21,17 @@ limitations under the License. */ ...@@ -21,12 +21,17 @@ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <algorithm>
#include "gflags/gflags.h" #include "gflags/gflags.h"
DEFINE_double(fraction_of_cpu_memory_to_use, 1, DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle," "Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc"); "reserve the rest for page tables, etc");
DEFINE_uint64(
initial_cpu_memory_in_mb, 500,
"Default initial 500MB of CPU memory for PaddlePaddle, in MD unit.");
DEFINE_double( DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5, fraction_of_cuda_pinned_memory_to_use, 0.5,
"Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
...@@ -54,7 +59,10 @@ inline size_t CpuTotalPhysicalMemory() { ...@@ -54,7 +59,10 @@ inline size_t CpuTotalPhysicalMemory() {
size_t CpuMaxAllocSize() { size_t CpuMaxAllocSize() {
// For distributed systems, it requires configuring and limiting // For distributed systems, it requires configuring and limiting
// the fraction of memory to use. // the fraction of memory to use.
return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory(); return std::min(
static_cast<size_t>(FLAGS_fraction_of_cpu_memory_to_use *
CpuTotalPhysicalMemory()),
static_cast<size_t>(FLAGS_initial_cpu_memory_in_mb * 1 << 20));
} }
size_t CpuMinChunkSize() { size_t CpuMinChunkSize() {
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -100,6 +101,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -100,6 +101,7 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) { void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::mutex> guard(mtx_);
callback(); callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
...@@ -116,6 +118,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -116,6 +118,8 @@ class CUDADeviceContext : public DeviceContext {
int compute_capability; int compute_capability;
int multi_process; int multi_process;
int max_threads_per_mp; int max_threads_per_mp;
std::mutex mtx_;
}; };
template <> template <>
......
...@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -413,6 +413,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("complete", &Executor::Complete)
#endif
.def("run", .def("run",
(void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) & (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
Executor::Run); Executor::Run);
...@@ -509,10 +512,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -509,10 +512,10 @@ All parameter, weight, gradient are variables in Paddle.
self.num_threads_ = num_threads; self.num_threads_ = num_threads;
}) })
.def_property( .def_property(
"use_event", "use_cuda",
[](const ExecutionStrategy &self) { return self.use_event_; }, [](const ExecutionStrategy &self) { return self.use_cuda_; },
[](ExecutionStrategy &self, bool use_event) { [](ExecutionStrategy &self, bool use_cuda) {
self.use_event_ = use_event; self.use_cuda_ = use_cuda;
}) })
.def_property( .def_property(
"allow_op_delay", "allow_op_delay",
......
...@@ -181,6 +181,7 @@ function build() { ...@@ -181,6 +181,7 @@ function build() {
============================================ ============================================
EOF EOF
make clean make clean
make -j `nproc`
make install -j `nproc` make install -j `nproc`
} }
......
...@@ -30,7 +30,7 @@ int main(int argc, char** argv) { ...@@ -30,7 +30,7 @@ int main(int argc, char** argv) {
new_argv.push_back( new_argv.push_back(
strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory")); strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
#else #else
new_argv.push_back(strdup("--tryfromenv=use_pinned_memory")); new_argv.push_back(strdup("--tryfromenv=use_pinned_memory,use_mkldnn"));
#endif #endif
int new_argc = static_cast<int>(new_argv.size()); int new_argc = static_cast<int>(new_argv.size());
char** new_argv_address = new_argv.data(); char** new_argv_address = new_argv.data();
......
...@@ -119,7 +119,8 @@ def reader_creator(data_file, ...@@ -119,7 +119,8 @@ def reader_creator(data_file,
yield sample, int(label) - 1 yield sample, int(label) - 1
if use_xmap: if use_xmap:
return xmap_readers(mapper, reader, cpu_count(), buffered_size) cpu_num = int(os.environ.get('CPU_NUM', cpu_count()))
return xmap_readers(mapper, reader, cpu_num, buffered_size)
else: else:
return map_readers(mapper, reader) return map_readers(mapper, reader)
......
...@@ -117,7 +117,7 @@ def __bootstrap__(): ...@@ -117,7 +117,7 @@ def __bootstrap__():
read_env_flags = [ read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope' 'eager_delete_scope', 'use_mkldnn'
] ]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
read_env_flags += [ read_env_flags += [
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import core import core
import numpy import numpy
import os
import six.moves as six import six.moves as six
import multiprocessing import multiprocessing
...@@ -150,7 +151,9 @@ class DataFeeder(object): ...@@ -150,7 +151,9 @@ class DataFeeder(object):
elif isinstance(self.place, core.CUDAPlace): elif isinstance(self.place, core.CUDAPlace):
return core.get_cuda_device_count() return core.get_cuda_device_count()
else: else:
return multiprocessing.cpu_count() cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
return cpu_num
def decorate_reader(self, def decorate_reader(self,
reader, reader,
......
...@@ -382,7 +382,7 @@ class Operator(object): ...@@ -382,7 +382,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send', 'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select' 'channel_recv', 'select', 'gen_nccl_id'
} }
def __init__(self, def __init__(self,
......
...@@ -95,7 +95,6 @@ def fc(input, ...@@ -95,7 +95,6 @@ def fc(input,
num_flatten_dims=1, num_flatten_dims=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=False,
use_mkldnn=False, use_mkldnn=False,
act=None, act=None,
is_test=False, is_test=False,
...@@ -222,6 +221,7 @@ def embedding(input, ...@@ -222,6 +221,7 @@ def embedding(input,
have two elements which indicate the size of the dictionary of have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively. embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update. is_sparse(bool): The flag indicating whether to use sparse update.
is_distributed (bool): Whether to run lookup table from remote parameter server.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup. padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If with zeros whenever lookup encounters it in :attr:`input`. If
...@@ -654,8 +654,9 @@ def dynamic_gru(input, ...@@ -654,8 +654,9 @@ def dynamic_gru(input,
:attr:`False`. :attr:`False`.
gate_activation(str): The activation for update gate and reset gate. gate_activation(str): The activation for update gate and reset gate.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid". Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
activation(str): The activation for candidate hidden state. candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh". Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
h_0 (Variable): The hidden output of the first time step.
Returns: Returns:
Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \ Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
...@@ -873,6 +874,13 @@ def cos_sim(X, Y): ...@@ -873,6 +874,13 @@ def cos_sim(X, Y):
""" """
This function performs the cosine similarity between two tensors This function performs the cosine similarity between two tensors
X and Y and returns that as the output. X and Y and returns that as the output.
Args:
X (Variable): The input X.
Y (Variable): The input Y.
Returns:
Variable: the output of cosine(X, Y).
""" """
helper = LayerHelper('cos_sim', **locals()) helper = LayerHelper('cos_sim', **locals())
out = helper.create_tmp_variable(dtype=X.dtype) out = helper.create_tmp_variable(dtype=X.dtype)
...@@ -899,14 +907,14 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -899,14 +907,14 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
unchanged. unchanged.
Args: Args:
x(variable): The input tensor. x (Variable): The input tensor.
dropout_prob(float): Probability of setting units to zero. dropout_prob (float): Probability of setting units to zero.
is_test(bool): A flag indicating whether it is in test phrase or not. is_test (bool): A flag indicating whether it is in test phrase or not.
seed(int): A Python integer used to create random seeds. If this seed (int): A Python integer used to create random seeds. If this
parameter is set to None, a random seed is used. parameter is set to None, a random seed is used.
NOTE: If an integer seed is given, always the same output NOTE: If an integer seed is given, always the same output
units will be dropped. DO NOT use a fixed seed in training. units will be dropped. DO NOT use a fixed seed in training.
name(str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
Returns: Returns:
...@@ -1029,8 +1037,8 @@ def square_error_cost(input, label): ...@@ -1029,8 +1037,8 @@ def square_error_cost(input, label):
* :math:`Out`: Output value, same shape with :math:`X`. * :math:`Out`: Output value, same shape with :math:`X`.
Args: Args:
input(Variable): Input tensor, has predictions. input (Variable): Input tensor, has predictions.
label(Variable): Label tensor, has target labels. label (Variable): Label tensor, has target labels.
Returns: Returns:
Variable: The tensor variable storing the element-wise squared error \ Variable: The tensor variable storing the element-wise squared error \
...@@ -1059,6 +1067,7 @@ def square_error_cost(input, label): ...@@ -1059,6 +1067,7 @@ def square_error_cost(input, label):
return square_out return square_out
@templatedoc()
def chunk_eval(input, def chunk_eval(input,
label, label,
chunk_scheme, chunk_scheme,
...@@ -1067,6 +1076,18 @@ def chunk_eval(input, ...@@ -1067,6 +1076,18 @@ def chunk_eval(input,
""" """
This function computes and outputs the precision, recall and This function computes and outputs the precision, recall and
F1-score of chunk detection. F1-score of chunk detection.
Args:
input (Variable): prediction output of the network.
label (Variable): label of the test data set.
chunk_scheme (str): ${chunk_scheme_comment}
num_chunk_types (int): ${num_chunk_types_comment}
excluded_chunk_types (list): ${excluded_chunk_types_comment}
Returns:
tuple: tuple containing: (precision, recall, f1_score,
num_infer_chunks, num_label_chunks,
num_correct_chunks)
""" """
helper = LayerHelper("chunk_eval", **locals()) helper = LayerHelper("chunk_eval", **locals())
...@@ -1099,6 +1120,7 @@ def chunk_eval(input, ...@@ -1099,6 +1120,7 @@ def chunk_eval(input,
num_correct_chunks) num_correct_chunks)
@templatedoc()
def sequence_conv(input, def sequence_conv(input,
num_filters, num_filters,
filter_size=3, filter_size=3,
...@@ -1111,6 +1133,19 @@ def sequence_conv(input, ...@@ -1111,6 +1133,19 @@ def sequence_conv(input,
This function creates the op for sequence_conv, using the inputs and This function creates the op for sequence_conv, using the inputs and
other convolutional configurations for the filters and stride as given other convolutional configurations for the filters and stride as given
in the input parameters to the function. in the input parameters to the function.
Args:
input (Variable): ${x_comment}
num_filters (int): number of filters.
filter_size (int): the filter size (H and W).
filter_stride (int): stride of the filter.
padding (bool): if True, add paddings.
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): attributes for parameter
act (str): the activation type
Returns:
Variable: output of sequence_conv
""" """
# FIXME(dzh) : want to unify the argument of python layer # FIXME(dzh) : want to unify the argument of python layer
...@@ -1225,32 +1260,33 @@ def conv2d(input, ...@@ -1225,32 +1260,33 @@ def conv2d(input,
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args: Args:
input(Variable): The input image with [N, C, H, W] format. input (Variable): The input image with [N, C, H, W] format.
num_filters(int): The number of filter. It is as same as the output num_filters(int): The number of filter. It is as same as the output
image channel. image channel.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple, filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W). it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square. Otherwise, the filter will be a square.
stride(int|tuple): The stride size. If stride is a tuple, it must stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1. stride_H = stride_W = stride. Default: stride = 1.
padding(int|tuple): The padding size. If padding is a tuple, it must padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0. padding_H = padding_W = padding. Default: padding = 0.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1. dilation_H = dilation_W = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d Layer. According to grouped groups (int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2, convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1 connected to the second half of the input channels. Default: groups=1
param_attr(ParamAttr): The parameters to the Conv2d Layer. Default: None param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None
bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True library is installed. Default: True
act(str): Activation type. Default: None use_mkldnn (bool): Use mkldnn kernels or not.
name(str|None): A name for this layer(optional). If set None, the layer act (str): Activation type. Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
Returns: Returns:
...@@ -1409,7 +1445,7 @@ def sequence_pool(input, pool_type): ...@@ -1409,7 +1445,7 @@ def sequence_pool(input, pool_type):
def sequence_first_step(input): def sequence_first_step(input):
""" """
This funciton get the first step of sequence. This function gets the first step of sequence.
.. code-block:: text .. code-block:: text
...@@ -1442,7 +1478,7 @@ def sequence_first_step(input): ...@@ -1442,7 +1478,7 @@ def sequence_first_step(input):
def sequence_last_step(input): def sequence_last_step(input):
""" """
This funciton get the last step of sequence. This function gets the last step of sequence.
.. code-block:: text .. code-block:: text
...@@ -1486,6 +1522,22 @@ def pool2d(input, ...@@ -1486,6 +1522,22 @@ def pool2d(input,
""" """
This function adds the operator for pooling in 2 dimensions, using the This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters. pooling configurations mentioned in input parameters.
Args:
input (Variable): ${input_comment}
pool_size (int): ${ksize_comment}
pool_type (str): ${pooling_type_comment}
pool_stride (int): stride of the pooling layer.
pool_padding (int): padding size.
global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment}
use_mkldnn (bool): ${use_mkldnn_comment}
name (str): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: output of pool2d layer.
""" """
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
...@@ -1543,6 +1595,25 @@ def batch_norm(input, ...@@ -1543,6 +1595,25 @@ def batch_norm(input,
""" """
This function helps create an operator to implement This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters. the BatchNorm layer using the configurations from the input parameters.
Args:
input (Variable): the input variable.
act (str): activation type
is_test (bool): whether to run batch_norm as test mode.
momentum (float): momentum
epsilon (float): epsilon, default 1e-05
param_attr (ParamAttr|None): attributes for parameter
bias_attr (ParamAttr|None): attributes for bias
data_layout (str): data layout, default NCHW
in_place (bool): if True, do not create tmp variable
use_mkldnn (bool): ${use_mkldnn_comment}
name (str): The name of this layer. It is optional.
moving_mean_name (str): The name of moving mean variable name, optional.
moving_variance_name (str): The name of moving variance name, optional.
do_model_average_for_mean_and_var (bool):
Returns:
Variable: output of batch_norm layer.
""" """
helper = LayerHelper('batch_norm', **locals()) helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
...@@ -1667,6 +1738,7 @@ def layer_norm(input, ...@@ -1667,6 +1738,7 @@ def layer_norm(input,
bias_attr(ParamAttr|None): The parameter attribute for the learnable bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`. bias :math:`b`.
act(str): Activation to be applied to the output of layer normalizaiton. act(str): Activation to be applied to the output of layer normalizaiton.
name (str): The name of this layer. It is optional.
Returns: Returns:
${y_comment} ${y_comment}
...@@ -1711,6 +1783,17 @@ def layer_norm(input, ...@@ -1711,6 +1783,17 @@ def layer_norm(input,
def beam_search_decode(ids, scores, name=None): def beam_search_decode(ids, scores, name=None):
"""
${beam_search_decode}
Args:
ids (Variable): ${ids_comment}
scores (Variable): ${scores_comment}
name (str): The name of this layer. It is optional.
Returns:
tuple: a tuple of two output variable: sentence_ids, sentence_scores
"""
helper = LayerHelper('beam_search_decode', **locals()) helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype) sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype) sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
...@@ -1962,6 +2045,17 @@ def sequence_expand(x, y, ref_level=-1, name=None): ...@@ -1962,6 +2045,17 @@ def sequence_expand(x, y, ref_level=-1, name=None):
def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0): def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
''' '''
This function implements the beam search algorithm. This function implements the beam search algorithm.
Args:
pre_ids (Variable): ${pre_ids_comment}
ids (Variable): ${ids_comment}
scores (Variable): ${scores_comment}
beam_size (int): ${beam_size_comment}
end_id (int): ${end_id_comment}
level (int): ${level_comment}
Returns:
tuple: a tuple of beam_search output variables: selected_ids, selected_scores
''' '''
helper = LayerHelper('beam_search', **locals()) helper = LayerHelper('beam_search', **locals())
score_type = scores.dtype score_type = scores.dtype
...@@ -2457,17 +2551,19 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): ...@@ -2457,17 +2551,19 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
The l2 normalize layer normalizes `x` along dimension `axis` using an L2 The l2 normalize layer normalizes `x` along dimension `axis` using an L2
norm. For a 1-D tensor (`dim` is fixed to 0), this layer computes norm. For a 1-D tensor (`dim` is fixed to 0), this layer computes
output = x / sqrt(max(sum(x**2), epsilon)) .. math::
y = \frac{x}{ \sqrt{\sum {x^2} + epsion }}
For `x` with more dimensions, this layer independently normalizes each 1-D For `x` with more dimensions, this layer independently normalizes each 1-D
slice along dimension `axis`. slice along dimension `axis`.
Args: Args:
x(Variable|list): The input tensor to l2_normalize layer. x(Variable|list): The input tensor to l2_normalize layer.
axis(int): Dimension along which to normalize the input. axis(int): The axis on which to apply normalization. If `axis < 0`,
epsilon(float): A lower bound value for `x`'s l2 norm. sqrt(epsilon) will the dimension to normalization is rank(X) + axis. -1 is the
be used as the divisor if the l2 norm of `x` is less than last dimension.
sqrt(epsilon). epsilon(float): The epsilon value is used to avoid division by zero,
the defalut value is 1e-10.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
...@@ -2488,46 +2584,17 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): ...@@ -2488,46 +2584,17 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
axis = 0 axis = 0
helper = LayerHelper("l2_normalize", **locals()) helper = LayerHelper("l2_normalize", **locals())
square = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="square", inputs={"X": x}, outputs={"Out": square}) norm = helper.create_tmp_variable(dtype=x.dtype)
reduced_sum = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reduce_sum", type="norm",
inputs={"X": square}, inputs={"X": x},
outputs={"Out": reduced_sum}, outputs={"Out": out,
"Norm": norm},
attrs={ attrs={
"dim": [1] if axis is None else [axis], "axis": 1 if axis is None else axis,
"keep_dim": True, "epsilon": epsilon,
"reduce_all": False
}) })
# TODO(caoying) A lower bound value epsilon for the norm is needed to
# imporve the numeric stability of reciprocal. This requires a maximum_op.
rsquare = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reciprocal", inputs={"X": reduced_sum}, outputs={"Out": rsquare})
# TODO(caoying) the current elementwise_mul operator does not support a
# general broadcast rule which broadcasts input(Y) to have the same
# dimension with Input(X) starting from a specified dimension. So this
# exanpsion is requred. Once a general broadcast rule is spported, this
# expanding canbe removed.
rsquare_expanded = helper.create_tmp_variable(dtype=x.dtype)
expand_times = [1] * len(x.shape)
expand_times[axis] = int(x.shape[axis])
helper.append_op(
type="expand",
inputs={"X": rsquare},
outputs={"Out": rsquare_expanded},
attrs={"expand_times": expand_times})
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="elementwise_mul",
inputs={"X": x,
"Y": rsquare_expanded},
outputs={"Out": out})
return out return out
...@@ -2711,16 +2778,13 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None, ...@@ -2711,16 +2778,13 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None,
the edit distance will be divided by the length of reference string. the edit distance will be divided by the length of reference string.
Args: Args:
input(Variable): The indices for hypothesis strings. input(Variable): The indices for hypothesis strings.
label(Variable): The indices for reference strings. label(Variable): The indices for reference strings.
normalized(bool): Indicated whether to normalize the edit distance by normalized(bool): Indicated whether to normalize the edit distance by
the length of reference string. the length of reference string.
ignored_tokens(list of int): Tokens that should be removed before ignored_tokens(list of int): Tokens that should be removed before
calculating edit distance. calculating edit distance.
name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: sequence-to-sequence edit distance in shape [batch_size, 1]. Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
...@@ -2810,10 +2874,10 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -2810,10 +2874,10 @@ def ctc_greedy_decoder(input, blank, name=None):
where Lp is the sum of all input sequences' length and where Lp is the sum of all input sequences' length and
num_classes is the true number of classes. (not num_classes is the true number of classes. (not
including the blank label). including the blank label).
blank(int): the blank label index of Connectionist Temporal blank(int): the blank label index of Connectionist Temporal
Classification (CTC) loss, which is in thehalf-opened Classification (CTC) loss, which is in thehalf-opened
interval [0, num_classes + 1). interval [0, num_classes + 1).
name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: CTC greedy decode result. If all the sequences in result were Variable: CTC greedy decode result. If all the sequences in result were
...@@ -2860,10 +2924,10 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -2860,10 +2924,10 @@ def warpctc(input, label, blank=0, norm_by_times=False):
of variable-length sequence, which is a 2-D Tensor with LoD of variable-length sequence, which is a 2-D Tensor with LoD
information. It is of the shape [Lg, 1], where Lg is th sum of information. It is of the shape [Lg, 1], where Lg is th sum of
all labels' length. all labels' length.
blank: (int, default: 0), the blank label index of Connectionist blank (int): default 0, the blank label index of Connectionist
Temporal Classification (CTC) loss, which is in the Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1). half-opened interval [0, num_classes + 1).
norm_by_times: (bool, default: false), whether to normalize norm_by_times (bool): default false, whether to normalize
the gradients by the number of time-step, which is also the the gradients by the number of time-step, which is also the
sequence's length. There is no need to normalize the gradients sequence's length. There is no need to normalize the gradients
if warpctc layer was follewed by a mean_op. if warpctc layer was follewed by a mean_op.
...@@ -2949,7 +3013,10 @@ def sequence_reshape(input, new_dim): ...@@ -2949,7 +3013,10 @@ def sequence_reshape(input, new_dim):
return out return out
@autodoc() # FIXME(wuyi): let docstring_checker.py understand @autodoc.
# For now, the comments in c++ use types like Tensor, but in python side
# the type is often "Variable", and arguments may vary.
@templatedoc(op_type="nce")
def nce(input, def nce(input,
label, label,
num_total_classes, num_total_classes,
...@@ -2957,6 +3024,21 @@ def nce(input, ...@@ -2957,6 +3024,21 @@ def nce(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
num_neg_samples=None): num_neg_samples=None):
"""
${comment}
Args:
input (Variable): input variable.
label (Variable): label.
num_total_classes (int):${num_total_classes_comment}
sample_weight (int): ${sample_weight_comment}
param_attr (ParamAttr|None): attributes for parameter
bias_attr (ParamAttr|None): attributes for bias
num_neg_samples (int): ${num_neg_samples_comment}
Returns:
Variable: output of nce layer.
"""
helper = LayerHelper('nce', **locals()) helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable) assert isinstance(input, Variable)
dim = input.shape[1] dim = input.shape[1]
...@@ -3014,8 +3096,9 @@ def transpose(x, perm, name=None): ...@@ -3014,8 +3096,9 @@ def transpose(x, perm, name=None):
perm[i]-th dimension of `input`. perm[i]-th dimension of `input`.
Args: Args:
input (Variable): (Tensor), A Tensor. x (Variable): The input Tensor.
perm (list): A permutation of the dimensions of `input`. perm (list): A permutation of the dimensions of `input`.
name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: A transposed Tensor. Variable: A transposed Tensor.
...@@ -3413,7 +3496,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -3413,7 +3496,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
begin(int): The first value of this counter. begin(int): The first value of this counter.
step(int): The increment step between each execution. step(int): The increment step between each execution.
Returns(Variable): The global run counter. Returns:
Variable: The global run counter.
""" """
helper = LayerHelper('global_step_counter') helper = LayerHelper('global_step_counter')
if counter_name is None: if counter_name is None:
...@@ -3474,7 +3558,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -3474,7 +3558,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
the corresponding dimension of x. the corresponding dimension of x.
Args: Args:
input(variable): The input tensor. x(variable): The input tensor.
shape(list): The new shape. At most one dimension of the new shape can shape(list): The new shape. At most one dimension of the new shape can
be -1. be -1.
actual_shape(variable): An optional input. If provided, reshape actual_shape(variable): An optional input. If provided, reshape
...@@ -3486,8 +3570,10 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -3486,8 +3570,10 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
inplace(bool): If this flag is set true, a new output tensor is created inplace(bool): If this flag is set true, a new output tensor is created
whose data is copied from input x, otherwise the output whose data is copied from input x, otherwise the output
shares data with input without copying. shares data with input without copying.
name (str): The name of this layer. It is optional.
Returns(variable): The output tensor. Returns:
Variable: The output tensor.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -4008,7 +4094,6 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -4008,7 +4094,6 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
name(str|None): The output variable name. name(str|None): The output variable name.
Returns: Returns:
${out_comment}. ${out_comment}.
""" """
...@@ -4027,6 +4112,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'): ...@@ -4027,6 +4112,7 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'):
This is a 4-D tensor of the shape This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w). (num_batches, channels, in_h, in_w).
out_short_len(int): The length of output images' short edge. out_short_len(int): The length of output images' short edge.
resample (str): resample method, default: BILINEAR.
Returns: Returns:
out (Variable): The output is a 4-D tensor of the shape out (Variable): The output is a 4-D tensor of the shape
......
...@@ -71,6 +71,7 @@ __all__ = [ ...@@ -71,6 +71,7 @@ __all__ = [
'cumsum', 'cumsum',
'scatter', 'scatter',
'sum', 'sum',
'slice',
'polygon_box_transform', 'polygon_box_transform',
'shape', 'shape',
'maxout', 'maxout',
......
...@@ -31,6 +31,8 @@ __all__ = [ ...@@ -31,6 +31,8 @@ __all__ = [
'assign', 'assign',
'fill_constant_batch_size_like', 'fill_constant_batch_size_like',
'fill_constant', 'fill_constant',
'argmin',
'argmax',
'ones', 'ones',
'zeros', 'zeros',
] ]
...@@ -326,6 +328,68 @@ def fill_constant_batch_size_like(input, ...@@ -326,6 +328,68 @@ def fill_constant_batch_size_like(input,
return out return out
def argmin(x, axis=0):
"""
**argmin**
This function computes the indices of the min elements
of the input tensor's element along the provided axis.
Args:
x(Variable): The input to compute the indices of
the min elements.
axis(int): Axis to compute indices along.
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
out = fluid.layers.argmin(x=in, axis=0)
out = fluid.layers.argmin(x=in, axis=-1)
"""
helper = LayerHelper("arg_min", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
helper.append_op(
type='arg_min',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
def argmax(x, axis=0):
"""
**argmax**
This function computes the indices of the max elements
of the input tensor's element along the provided axis.
Args:
x(Variable): The input to compute the indices of
the max elements.
axis(int): Axis to compute indices along.
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
out = fluid.layers.argmax(x=in, axis=0)
out = fluid.layers.argmax(x=in, axis=-1)
"""
helper = LayerHelper("arg_max", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
helper.append_op(
type='arg_max',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
def ones(shape, dtype, force_cpu=False): def ones(shape, dtype, force_cpu=False):
""" """
**ones** **ones**
......
...@@ -18,6 +18,7 @@ import framework ...@@ -18,6 +18,7 @@ import framework
import executor import executor
import warnings import warnings
import sys import sys
import os
__all__ = ['ParallelExecutor', 'ExecutionStrategy', 'BuildStrategy'] __all__ = ['ParallelExecutor', 'ExecutionStrategy', 'BuildStrategy']
...@@ -101,7 +102,9 @@ class ParallelExecutor(object): ...@@ -101,7 +102,9 @@ class ParallelExecutor(object):
p.set_place(self._act_places[-1]) p.set_place(self._act_places[-1])
self._places.append(p) self._places.append(p)
else: else:
for i in xrange(multiprocessing.cpu_count()): cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
for i in xrange(cpu_num):
p = core.Place() p = core.Place()
self._act_places.append(core.CPUPlace()) self._act_places.append(core.CPUPlace())
p.set_place(self._act_places[-1]) p.set_place(self._act_places[-1])
...@@ -110,19 +113,17 @@ class ParallelExecutor(object): ...@@ -110,19 +113,17 @@ class ParallelExecutor(object):
if exec_strategy is None: if exec_strategy is None:
exec_strategy = ExecutionStrategy() exec_strategy = ExecutionStrategy()
if use_cuda: exec_strategy.use_cuda = use_cuda
exec_strategy.use_event = True
else:
exec_strategy.use_event = False
if exec_strategy.num_threads == 0: if exec_strategy.num_threads == 0:
if use_cuda: if use_cuda:
# Experiments on se-resnext shows that too many threads hurt # Experiments on se-resnext shows that too many threads hurt
# performance. Worth tunning for other models in the future. # performance. Worth tunning for other models in the future.
exec_strategy.num_threads = len(self._places) * 2 exec_strategy.num_threads = len(self._places) * 4
else: else:
exec_strategy.num_threads = min( cpu_num = int(
len(self._places) * 2, multiprocessing.cpu_count()) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exec_strategy.num_threads = cpu_num
if build_strategy is None: if build_strategy is None:
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
......
...@@ -41,8 +41,8 @@ function(py_test_modules TARGET_NAME) ...@@ -41,8 +41,8 @@ function(py_test_modules TARGET_NAME)
endfunction() endfunction()
list(REMOVE_ITEM TEST_OPS test_warpctc_op) list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dist_train) list(REMOVE_ITEM TEST_OPS test_dist_train)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) #list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) #list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
# TODO(wuyi): this test hungs on CI, will add it back later # TODO(wuyi): this test hungs on CI, will add it back later
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op) list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import multiprocessing
import os
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import time import time
...@@ -23,6 +25,7 @@ __all__ = ['TestParallelExecutorBase'] ...@@ -23,6 +25,7 @@ __all__ = ['TestParallelExecutorBase']
class TestParallelExecutorBase(unittest.TestCase): class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self, def check_network_convergence(self,
method, method,
use_cuda=True,
memory_opt=True, memory_opt=True,
iter=50, iter=50,
batch_size=None, batch_size=None,
...@@ -53,7 +56,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -53,7 +56,7 @@ class TestParallelExecutorBase(unittest.TestCase):
adam.minimize(loss) adam.minimize(loss)
if memory_opt: if memory_opt:
fluid.memory_optimize(main) fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
startup_exe.run(startup) startup_exe.run(startup)
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
...@@ -64,7 +67,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -64,7 +67,7 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, use_cuda,
loss_name=loss.name, loss_name=loss.name,
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
build_strategy=build_strategy) build_strategy=build_strategy)
...@@ -72,7 +75,9 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -72,7 +75,9 @@ class TestParallelExecutorBase(unittest.TestCase):
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count(
) if use_cuda else int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time() begin = time.time()
first_loss, = run_executor( first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name]) exe=exe, feed=feed_dict, fetch_list=[loss.name])
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class BaseTestCase(OpTest):
def initTestCase(self):
self.op_type = 'arg_min'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
def setUp(self):
self.initTestCase()
self.x = (1000 * np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
if self.op_type == "arg_min":
self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
else:
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
def test_check_output(self):
self.check_output()
class TestCase0(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
class TestCase1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_min'
self.dims = (3, 4)
self.dtype = 'float64'
self.axis = 1
class TestCase2(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'int64'
self.axis = 0
class TestCase3(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, )
self.dtype = 'int64'
self.axis = 0
class TestCase4(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_min'
self.dims = (1, )
self.dtype = 'int32'
self.axis = 0
if __name__ == '__main__':
unittest.main()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import delete_ops from paddle.fluid.transpiler.distribute_transpiler import delete_ops
...@@ -54,10 +55,10 @@ class TestDistTranspiler(TranspilerTest): ...@@ -54,10 +55,10 @@ class TestDistTranspiler(TranspilerTest):
delete_ops(trainer.global_block(), optimize_ops) delete_ops(trainer.global_block(), optimize_ops)
ops = [op.type for op in trainer.global_block().ops] + [ ops = [op.type for op in trainer.global_block().ops] + [
"split_byref", "send_vars", "send_barrier", "recv", "recv", "split_byref", "send", "send_barrier", "recv", "recv",
"fetch_barrier", "concat" "fetch_barrier", "concat"
] ]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") ops.insert(ops.index("elementwise_add_grad") + 1, "send")
return ops return ops
......
...@@ -387,6 +387,12 @@ class TestBook(unittest.TestCase): ...@@ -387,6 +387,12 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output) self.assertIsNotNone(output)
print(str(program)) print(str(program))
def test_l2_normalize(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[8, 7, 10], dtype="float32")
output = layers.l2_normalize(x, axis=1)
def test_maxout(self): def test_maxout(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
......
...@@ -70,17 +70,18 @@ class TestListenAndServOp(OpTest): ...@@ -70,17 +70,18 @@ class TestListenAndServOp(OpTest):
return p.pid return p.pid
def _wait_ps_ready(self, pid): def _wait_ps_ready(self, pid):
retry_times = self.ps_timeout start_left_time = self.ps_timeout
sleep_time = 0.5
while True: while True:
assert retry_times >= 0, "wait ps ready failed" assert start_left_time >= 0, "wait ps ready failed"
time.sleep(0.5) time.sleep(sleep_time)
try: try:
# the listen_and_serv_op would touch a file which contains the listen port # the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call. # on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid) os.stat("/tmp/paddle.%d.port" % pid)
return return
except os.error: except os.error:
retry_times -= 1 start_left_time -= sleep_time
def test_rpc_interfaces(self): def test_rpc_interfaces(self):
# TODO(Yancey1989): need to make sure the rpc interface correctly. # TODO(Yancey1989): need to make sure the rpc interface correctly.
......
...@@ -17,44 +17,23 @@ import numpy as np ...@@ -17,44 +17,23 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def norm(input, scale, epsilon): def l2_norm(x, axis, epsilon):
s0, s1, s2, s3 = input.shape x2 = x**2
x_square = input * input s = np.sum(x2, axis=axis, keepdims=True)
for i in xrange(s0): r = np.sqrt(s + epsilon)
input_batch = input[i:i + 1, :, :, :] y = x / np.broadcast_to(r, x.shape)
input_batch = input_batch.reshape(s1, s2 * s3) return y, r
x_square_batch = x_square[i:i + 1, :, :, :]
x_square_batch = x_square_batch.reshape(s1, s2 * s3)
square_colsum = x_square_batch.sum(axis=0) + epsilon
tmp = pow(square_colsum, 0.5)
tmp = np.reciprocal(tmp)
tmp_tile = np.tile(tmp, s1)
tmp_tile = tmp_tile.reshape(s1, s2 * s3)
scale_tile = np.tile(scale, (1, s2 * s3))
scale_tile = scale_tile.reshape(s1, s2 * s3)
out_batch = input_batch * tmp_tile * scale_tile
out_batch = out_batch.reshape(1, s1, s2, s3)
if i == 0:
out = out_batch
else:
out = np.concatenate((out, out_batch), 0)
out.reshape(s0, s1, s2, s3)
return out
class TestNormOp(OpTest): class TestNormOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "norm" self.op_type = "norm"
self.init_test_case() self.init_test_case()
input = np.random.random(self.shape).astype("float32") x = np.random.random(self.shape).astype("float64")
scale = np.array([10, 10, 10]) y, norm = l2_norm(x, self.axis, self.epsilon)
self.inputs = { self.inputs = {'X': x}
'X': input.astype('float32'), self.attrs = {'epsilon': self.epsilon, 'axis': self.axis}
'Scale': scale.astype('float32') self.outputs = {'Out': y, 'Norm': norm}
}
self.attrs = {'epsilon': self.epsilon}
output = norm(input, scale, self.epsilon)
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -63,8 +42,23 @@ class TestNormOp(OpTest): ...@@ -63,8 +42,23 @@ class TestNormOp(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def init_test_case(self): def init_test_case(self):
self.shape = [2, 3, 2, 2] self.shape = [2, 3, 4, 4]
self.epsilon = 1e-6 self.axis = 1
self.epsilon = 1e-8
class TestNormOp2(TestNormOp):
def init_test_case(self):
self.shape = [5, 3, 9, 7]
self.axis = 0
self.epsilon = 1e-8
class TestNormOp3(TestNormOp):
def init_test_case(self):
self.shape = [5, 3, 2, 7]
self.axis = -1
self.epsilon = 1e-8
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -70,8 +70,9 @@ class TestNormalization(unittest.TestCase): ...@@ -70,8 +70,9 @@ class TestNormalization(unittest.TestCase):
def l2_normalize(self, data, axis, epsilon): def l2_normalize(self, data, axis, epsilon):
""" Compute the groundtruth. """ Compute the groundtruth.
""" """
output = data * np.reciprocal( output = data / np.broadcast_to(
np.sum(np.square(data), axis=axis, keepdims=True)) np.sqrt(np.sum(np.square(data), axis=axis, keepdims=True)),
data.shape)
return output return output
def test_l2_normalize(self): def test_l2_normalize(self):
......
...@@ -17,6 +17,7 @@ import paddle.fluid as fluid ...@@ -17,6 +17,7 @@ import paddle.fluid as fluid
import unittest import unittest
import paddle import paddle
import numpy as np import numpy as np
import os
word_dict, verb_dict, label_dict = conll05.get_dict() word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict) word_dict_len = len(word_dict)
...@@ -101,7 +102,11 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, ...@@ -101,7 +102,11 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
class TestCRFModel(unittest.TestCase): class TestCRFModel(unittest.TestCase):
def check_network_convergence(self, is_sparse, build_strategy=None): def check_network_convergence(self,
is_sparse,
build_strategy=None,
use_cuda=True):
os.environ['CPU_NUM'] = str(4)
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -145,12 +150,12 @@ class TestCRFModel(unittest.TestCase): ...@@ -145,12 +150,12 @@ class TestCRFModel(unittest.TestCase):
paddle.dataset.conll05.test(), buf_size=8192), paddle.dataset.conll05.test(), buf_size=8192),
batch_size=16) batch_size=16)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor( pe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=use_cuda,
loss_name=avg_cost.name, loss_name=avg_cost.name,
build_strategy=build_strategy) build_strategy=build_strategy)
...@@ -172,25 +177,33 @@ class TestCRFModel(unittest.TestCase): ...@@ -172,25 +177,33 @@ class TestCRFModel(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
self.check_network_convergence( self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy) is_sparse=True, build_strategy=build_strategy, use_cuda=True)
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
def test_update_dense_parameter_all_reduce(self): def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
self.check_network_convergence( self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy) is_sparse=False, build_strategy=build_strategy, use_cuda=True)
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
def test_update_sparse_parameter_reduce(self): def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.check_network_convergence( self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy) is_sparse=True, build_strategy=build_strategy, use_cuda=True)
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
def test_update_dense_parameter_reduce(self): def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.check_network_convergence( self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy) is_sparse=False, build_strategy=build_strategy, use_cuda=True)
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -18,6 +18,7 @@ import paddle.fluid as fluid ...@@ -18,6 +18,7 @@ import paddle.fluid as fluid
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
import os
def Lenet(data, class_dim): def Lenet(data, class_dim):
...@@ -35,7 +36,7 @@ def Lenet(data, class_dim): ...@@ -35,7 +36,7 @@ def Lenet(data, class_dim):
class TestFetchOp(unittest.TestCase): class TestFetchOp(unittest.TestCase):
def parallel_exe(self, train_inputs, seed): def parallel_exe(self, train_inputs, seed, use_cuda):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
startup.random_seed = seed startup.random_seed = seed
...@@ -59,13 +60,13 @@ class TestFetchOp(unittest.TestCase): ...@@ -59,13 +60,13 @@ class TestFetchOp(unittest.TestCase):
# conv2d_1.b_0@GRAD. Those variables should not be pruned. # conv2d_1.b_0@GRAD. Those variables should not be pruned.
# fluid.memory_optimize(main) # fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
feeder = fluid.DataFeeder(place=place, feed_list=[data, label]) feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
pe = fluid.ParallelExecutor( pe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main) use_cuda=use_cuda, loss_name=loss.name, main_program=main)
fetch_list = [] fetch_list = []
all_vars = main.global_block().vars all_vars = main.global_block().vars
...@@ -88,14 +89,16 @@ class TestFetchOp(unittest.TestCase): ...@@ -88,14 +89,16 @@ class TestFetchOp(unittest.TestCase):
for i in range(iters): for i in range(iters):
train_inputs.append(tst_reader_iter.next()) train_inputs.append(tst_reader_iter.next())
self.parallel_exe(train_inputs, seed=1) os.environ['CPU_NUM'] = str(4)
self.parallel_exe(train_inputs, seed=1, use_cuda=True)
self.parallel_exe(train_inputs, seed=1, use_cuda=False)
class TestFeedParallel(unittest.TestCase): class TestFeedParallel(unittest.TestCase):
def test_main(self): def parallel_exe(self, use_cuda, seed):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
startup.random_seed = 1 startup.random_seed = seed
with fluid.scope_guard(fluid.core.Scope()): with fluid.scope_guard(fluid.core.Scope()):
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
data = fluid.layers.data( data = fluid.layers.data(
...@@ -111,15 +114,18 @@ class TestFeedParallel(unittest.TestCase): ...@@ -111,15 +114,18 @@ class TestFeedParallel(unittest.TestCase):
regularization=fluid.regularizer.L2Decay(1e-4)) regularization=fluid.regularizer.L2Decay(1e-4))
opt.minimize(loss) opt.minimize(loss)
place = fluid.CUDAPlace(0)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[data, label]) feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
reader = feeder.decorate_reader( reader = feeder.decorate_reader(
paddle.batch( paddle.batch(
flowers.train(), batch_size=16), multi_devices=True) flowers.train(), batch_size=16), multi_devices=True)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor( pe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main) use_cuda=use_cuda, loss_name=loss.name, main_program=main)
for batch_id, data in enumerate(reader()): for batch_id, data in enumerate(reader()):
loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0]) loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0])
...@@ -127,6 +133,11 @@ class TestFeedParallel(unittest.TestCase): ...@@ -127,6 +133,11 @@ class TestFeedParallel(unittest.TestCase):
if batch_id == 2: if batch_id == 2:
break break
def test_feed_op(self):
os.environ['CPU_NUM'] = str(4)
self.parallel_exe(use_cuda=True, seed=1)
self.parallel_exe(use_cuda=False, seed=1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import paddle import paddle
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
import unittest import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio" MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
...@@ -85,6 +86,7 @@ def fc_with_batchnorm(use_feed): ...@@ -85,6 +86,7 @@ def fc_with_batchnorm(use_feed):
class TestMNIST(TestParallelExecutorBase): class TestMNIST(TestParallelExecutorBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file # Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4) reader = paddle.batch(mnist.train(), batch_size=4)
...@@ -99,9 +101,12 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -99,9 +101,12 @@ class TestMNIST(TestParallelExecutorBase):
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder) MNIST_RECORDIO_FILE, reader, feeder)
def check_simple_fc_convergence(self, balance_parameter_opt_between_cards): def check_simple_fc_convergence(self,
self.check_network_convergence(simple_fc_net) balance_parameter_opt_between_cards,
self.check_network_convergence(simple_fc_net, allow_op_delay=True) use_cuda=True):
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
...@@ -109,17 +114,21 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -109,17 +114,21 @@ class TestMNIST(TestParallelExecutorBase):
simple_fc_net, simple_fc_net,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
def test_simple_fc(self): def test_simple_fc(self):
self.check_simple_fc_convergence(False) self.check_simple_fc_convergence(False, use_cuda=True)
self.check_simple_fc_convergence(False, use_cuda=False)
def test_simple_fc_with_new_strategy(self): def test_simple_fc_with_new_strategy(self):
self.check_simple_fc_convergence(True) self.check_simple_fc_convergence(True, use_cuda=True)
self.check_simple_fc_convergence(True, use_cuda=False)
def check_simple_fc_parallel_accuracy(self, def check_simple_fc_parallel_accuracy(self,
balance_parameter_opt_between_cards): balance_parameter_opt_between_cards,
use_cuda=True):
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
...@@ -127,12 +136,14 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -127,12 +136,14 @@ class TestMNIST(TestParallelExecutorBase):
seed=1000, seed=1000,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda,
use_parallel_executor=False) use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence( parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net, method=simple_fc_net,
seed=1000, seed=1000,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda,
use_parallel_executor=True, use_parallel_executor=True,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
...@@ -143,28 +154,33 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -143,28 +154,33 @@ class TestMNIST(TestParallelExecutorBase):
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6) self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_simple_fc_parallel_accuracy(self): def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy(False) self.check_simple_fc_parallel_accuracy(False, use_cuda=True)
self.check_simple_fc_parallel_accuracy(False, use_cuda=False)
def test_simple_fc_parallel_accuracy_with_new_strategy(self): def test_simple_fc_parallel_accuracy_with_new_strategy(self):
self.check_simple_fc_parallel_accuracy(True) self.check_simple_fc_parallel_accuracy(True, use_cuda=True)
self.check_simple_fc_parallel_accuracy(True, use_cuda=False)
def check_batchnorm_fc_convergence(self, def check_batchnorm_fc_convergence(
balance_parameter_opt_between_cards): self, balance_parameter_opt_between_cards, use_cuda):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
fc_with_batchnorm, fc_with_batchnorm,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(False) self.check_batchnorm_fc_convergence(False, use_cuda=True)
self.check_batchnorm_fc_convergence(False, use_cuda=False)
def test_batchnorm_fc_with_new_strategy(self): def test_batchnorm_fc_with_new_strategy(self):
self.check_batchnorm_fc_convergence(True) self.check_batchnorm_fc_convergence(True, use_cuda=True)
self.check_batchnorm_fc_convergence(True, use_cuda=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase from parallel_executor_test_base import TestParallelExecutorBase
import unittest import unittest
import os
def squeeze_excitation(input, num_channels, reduction_ratio): def squeeze_excitation(input, num_channels, reduction_ratio):
...@@ -130,22 +131,30 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False): ...@@ -130,22 +131,30 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
def check_resnet_convergence(self, balance_parameter_opt_between_cards): def check_resnet_convergence(self,
balance_parameter_opt_between_cards,
use_cuda=True,
iter=20):
os.environ['CPU_NUM'] = str(4)
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
functools.partial( functools.partial(
SE_ResNeXt50Small, batch_size=batch_size), SE_ResNeXt50Small, batch_size=batch_size),
iter=20, iter=iter,
batch_size=batch_size, batch_size=batch_size,
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
def test_resnet(self): def test_resnet(self):
self.check_resnet_convergence(False) self.check_resnet_convergence(False, use_cuda=True)
self.check_resnet_convergence(False, use_cuda=False, iter=5)
def test_resnet_with_new_strategy(self): def test_resnet_with_new_strategy(self):
self.check_resnet_convergence(True) self.check_resnet_convergence(True, use_cuda=True)
self.check_resnet_convergence(True, use_cuda=False, iter=5)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import unittest import unittest
import os
def simple_fc_net(): def simple_fc_net():
...@@ -35,7 +36,8 @@ def simple_fc_net(): ...@@ -35,7 +36,8 @@ def simple_fc_net():
class ParallelExecutorTestingDuringTraining(unittest.TestCase): class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def check_network_convergence(self, build_strategy=None): def check_network_convergence(self, use_cuda, build_strategy=None):
os.environ['CPU_NUM'] = str(4)
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -49,19 +51,19 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -49,19 +51,19 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
image = np.random.normal(size=(batch_size, 784)).astype('float32') image = np.random.normal(size=(batch_size, 784)).astype('float32')
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64") label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=use_cuda,
loss_name=loss.name, loss_name=loss.name,
main_program=main, main_program=main,
build_strategy=build_strategy) build_strategy=build_strategy)
test_exe = fluid.ParallelExecutor( test_exe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=use_cuda,
main_program=test_program, main_program=test_program,
share_vars_from=train_exe, share_vars_from=train_exe,
build_strategy=build_strategy) build_strategy=build_strategy)
...@@ -81,12 +83,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -81,12 +83,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing(self): def test_parallel_testing(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
self.check_network_convergence(build_strategy) self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence(
use_cuda=False, build_strategy=build_strategy)
def test_parallel_testing_with_new_strategy(self): def test_parallel_testing_with_new_strategy(self):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.check_network_convergence(build_strategy) self.check_network_convergence(
use_cuda=True, build_strategy=build_strategy)
self.check_network_convergence(
use_cuda=False, build_strategy=build_strategy)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -19,6 +19,7 @@ from parallel_executor_test_base import TestParallelExecutorBase ...@@ -19,6 +19,7 @@ from parallel_executor_test_base import TestParallelExecutorBase
import unittest import unittest
import paddle import paddle
import paddle.dataset.wmt16 as wmt16 import paddle.dataset.wmt16 as wmt16
import os
WMT16_RECORDIO_FILE = "./wmt16_test_pe.recordio" WMT16_RECORDIO_FILE = "./wmt16_test_pe.recordio"
...@@ -149,6 +150,7 @@ def transformer(use_feed): ...@@ -149,6 +150,7 @@ def transformer(use_feed):
class TestTransformer(TestParallelExecutorBase): class TestTransformer(TestParallelExecutorBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
reader = paddle.batch( reader = paddle.batch(
wmt16.train(ModelHyperParams.src_vocab_size, wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size), ModelHyperParams.trg_vocab_size),
...@@ -167,7 +169,8 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -167,7 +169,8 @@ class TestTransformer(TestParallelExecutorBase):
@unittest.skip("transformer is buggy in multi gpu") @unittest.skip("transformer is buggy in multi gpu")
def test_main(self): def test_main(self):
self.check_network_convergence(transformer) self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(transformer, use_cuda=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -59,9 +59,9 @@ class TestSimpleDistTranspiler(TranspilerTest): ...@@ -59,9 +59,9 @@ class TestSimpleDistTranspiler(TranspilerTest):
delete_ops(trainer.global_block(), optimize_ops) delete_ops(trainer.global_block(), optimize_ops)
ops = [op.type for op in trainer.global_block().ops] + [ ops = [op.type for op in trainer.global_block().ops] + [
"send_vars", "send_barrier", "recv", "recv", "fetch_barrier" "send", "send_barrier", "recv", "recv", "fetch_barrier"
] ]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") ops.insert(ops.index("elementwise_add_grad") + 1, "send")
return ops return ops
def _transpiler_instance(self): def _transpiler_instance(self):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestSliceOp(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.out = self.input[1:3, 0:3, 2:4, :]
def test_check_output(self):
self.check_output()
class TestCase1(TestSliceOp):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 2]
self.out = self.input[-3:3, 0:100, 2:-1, :]
class TestCase2(TestSliceOp):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
if __name__ == '__main__':
unittest.main()
...@@ -24,9 +24,9 @@ Steps to transpile trainer: ...@@ -24,9 +24,9 @@ Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width). 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d". 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable. 3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch 4. append send_op to send splited variables to server and
params(splited blocks or origin param) from server. 5. add recv_op to fetch params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights. 6. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver: Steps to transpile pserver:
1. create new program for parameter server. 1. create new program for parameter server.
...@@ -317,7 +317,7 @@ class DistributeTranspiler: ...@@ -317,7 +317,7 @@ class DistributeTranspiler:
program.global_block().insert_op( program.global_block().insert_op(
index=index + 1, index=index + 1,
type="send_vars", type="send",
inputs={"X": splited_vars}, inputs={"X": splited_vars},
outputs={}, outputs={},
attrs={ attrs={
...@@ -515,35 +515,38 @@ class DistributeTranspiler: ...@@ -515,35 +515,38 @@ class DistributeTranspiler:
grad_to_block_id, None) grad_to_block_id, None)
# process distributed lookup_table # process distributed lookup_table
prefetch_block = None prefetch_var_name_to_block_id = []
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block( prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
assert prefetch_block is not None assert len(prefetch_var_name_to_block_id) > 0
else: else:
assert prefetch_block is None assert len(prefetch_var_name_to_block_id) == 0
prefetch_block = pserver_program.global_block()
# step5 append the listen_and_serv op attrs = {
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={'X': recv_inputs},
outputs={},
attrs={
"OptimizeBlock": pserver_program.block(1), "OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id "grad_to_block_id": grad_to_block_id
}) }
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={'X': recv_inputs},
outputs={},
attrs=attrs)
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
...@@ -608,8 +611,15 @@ class DistributeTranspiler: ...@@ -608,8 +611,15 @@ class DistributeTranspiler:
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None # self.all_prefetch_input_vars =
self.prefetch_output_vars = None # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = []
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
while continue_search_lookup_table_op: while continue_search_lookup_table_op:
...@@ -623,18 +633,19 @@ class DistributeTranspiler: ...@@ -623,18 +633,19 @@ class DistributeTranspiler:
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]] ids_var = program.global_block().vars[ids_name[0]]
self.prefetch_input_vars = self.create_splited_vars( prefetch_input_vars = self.create_splited_vars(
source_var=ids_var, source_var=ids_var,
block=program.global_block(), block=program.global_block(),
tag="_prefetch_in_") tag="_prefetch_in_")
if self.prefetch_output_vars is None: self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]] out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars( prefetch_output_vars = self.create_splited_vars(
source_var=out_var, source_var=out_var,
block=program.global_block(), block=program.global_block(),
tag="_prefetch_out_") tag="_prefetch_out_")
self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -646,14 +657,14 @@ class DistributeTranspiler: ...@@ -646,14 +657,14 @@ class DistributeTranspiler:
for varname in ids_name for varname in ids_name
] ]
}, },
outputs={"Out": self.prefetch_input_vars}) outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 1, index=op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars}, outputs={"Out": prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
...@@ -663,7 +674,7 @@ class DistributeTranspiler: ...@@ -663,7 +674,7 @@ class DistributeTranspiler:
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=op_index + 2,
type="concat", type="concat",
inputs={'X': self.prefetch_output_vars}, inputs={'X': prefetch_output_vars},
outputs={ outputs={
"Out": [ "Out": [
program.global_block().vars[varname] program.global_block().vars[varname]
...@@ -678,7 +689,7 @@ class DistributeTranspiler: ...@@ -678,7 +689,7 @@ class DistributeTranspiler:
break break
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers # 2. add split_ids_op and send_op to send gradient to pservers
# there should only be one table_name # there should only be one table_name
all_ops = program.global_block().ops all_ops = program.global_block().ops
table_grad_name = grad_var_name(self.table_name) table_grad_name = grad_var_name(self.table_name)
...@@ -695,11 +706,11 @@ class DistributeTranspiler: ...@@ -695,11 +706,11 @@ class DistributeTranspiler:
outputs={"Out": self.trainer_side_table_grad_list}) outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=op_index + 2,
type="send_vars", type="send",
inputs={'X': self.trainer_side_table_grad_list}, inputs={'X': self.trainer_side_table_grad_list},
outputs={}, outputs={},
attrs={ attrs={
"sync_send": True, "sync_mode": True,
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
...@@ -709,14 +720,16 @@ class DistributeTranspiler: ...@@ -709,14 +720,16 @@ class DistributeTranspiler:
optimize_block): optimize_block):
# STEP: create prefetch block # STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name] table_var = pserver_program.global_block().vars[self.table_name]
prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program.create_block(optimize_block.idx) prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.prefetch_input_vars[pserver_index] trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
pserver_ids = pserver_program.global_block().create_var( pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name, name=trainer_ids.name,
type=trainer_ids.type, type=trainer_ids.type,
shape=trainer_ids.shape, shape=trainer_ids.shape,
dtype=trainer_ids.dtype) dtype=trainer_ids.dtype)
trainer_out = self.prefetch_output_vars[pserver_index] trainer_out = self.all_prefetch_output_vars[index][pserver_index]
pserver_out = pserver_program.global_block().create_var( pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name, name=trainer_out.name,
type=trainer_out.type, type=trainer_out.type,
...@@ -732,7 +745,9 @@ class DistributeTranspiler: ...@@ -732,7 +745,9 @@ class DistributeTranspiler:
"is_distributed": True, "is_distributed": True,
"padding_idx": -1 "padding_idx": -1
}) })
return prefetch_block prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
prefetch_block.idx))
return prefetch_var_name_to_block_id
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id): pre_block_idx, grad_to_block_id):
......
...@@ -119,7 +119,8 @@ def reader_creator(data_file, ...@@ -119,7 +119,8 @@ def reader_creator(data_file,
yield sample, int(label) - 1 yield sample, int(label) - 1
if use_xmap: if use_xmap:
return xmap_readers(mapper, reader, cpu_count(), buffered_size) cpu_num = int(os.environ.get('CPU_NUM', cpu_count()))
return xmap_readers(mapper, reader, cpu_num, buffered_size)
else: else:
return map_readers(mapper, reader) return map_readers(mapper, reader)
......
...@@ -126,7 +126,8 @@ class DocstringChecker(BaseChecker): ...@@ -126,7 +126,8 @@ class DocstringChecker(BaseChecker):
'W9002': 'W9002':
('Doc string does not end with "." period', symbol + "-end-with", ('Doc string does not end with "." period', symbol + "-end-with",
'Used when a doc string does not end with a period'), 'Used when a doc string does not end with a period'),
'W9003': ('All args with their types must be mentioned in doc string', 'W9003':
('All args with their types must be mentioned in doc string %s',
symbol + "-with-all-args", symbol + "-with-all-args",
'Used when not all arguments are in the doc string '), 'Used when not all arguments are in the doc string '),
'W9005': ('Missing docstring or docstring is too short', 'W9005': ('Missing docstring or docstring is too short',
...@@ -178,6 +179,8 @@ class DocstringChecker(BaseChecker): ...@@ -178,6 +179,8 @@ class DocstringChecker(BaseChecker):
self.indent_style(node) self.indent_style(node)
def missing_doc_string(self, node): def missing_doc_string(self, node):
if node.name.startswith("__") or node.name.startswith("_"):
return True
if node.tolineno - node.fromlineno <= 10: if node.tolineno - node.fromlineno <= 10:
return True return True
...@@ -199,12 +202,16 @@ class DocstringChecker(BaseChecker): ...@@ -199,12 +202,16 @@ class DocstringChecker(BaseChecker):
doc = node.doc doc = node.doc
lines = doc.splitlines() lines = doc.splitlines()
line_num = 0
for l in lines: for l in lines:
if line_num == 0:
continue
cur_indent = len(l) - len(l.lstrip()) cur_indent = len(l) - len(l.lstrip())
if cur_indent % indent != 0: if cur_indent % indent != 0:
self.add_message('W9006', node=node, line=node.fromlineno) self.add_message('W9006', node=node, line=node.fromlineno)
return False return False
line_num += 1
return True return True
...@@ -320,15 +327,19 @@ class DocstringChecker(BaseChecker): ...@@ -320,15 +327,19 @@ class DocstringChecker(BaseChecker):
return True return True
parsed_args = doc.args parsed_args = doc.args
args_not_documented = set(args) - set(parsed_args)
if len(args) > 0 and len(parsed_args) <= 0: if len(args) > 0 and len(parsed_args) <= 0:
print "debug:parsed args: ", parsed_args self.add_message(
self.add_message('W9003', node=node, line=node.fromlineno) 'W9003',
node=node,
line=node.fromlineno,
args=list(args_not_documented))
return False return False
for t in args: for t in args:
if t not in parsed_args: if t not in parsed_args:
print t, " with (type) not in ", parsed_args self.add_message(
self.add_message('W9003', node=node, line=node.fromlineno) 'W9003', node=node, line=node.fromlineno, args=[t, ])
return False return False
return True return True
...@@ -7,13 +7,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" ...@@ -7,13 +7,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$DIR:$PYTHONPATH export PYTHONPATH=$DIR:$PYTHONPATH
# The trick to remove deleted files: https://stackoverflow.com/a/2413151 # The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}'); do for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
pylint --disable=all --load-plugins=docstring_checker \ pylint --disable=all --load-plugins=docstring_checker \
--enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file; --enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done done
#exit $TOTAL_ERRORS exit $TOTAL_ERRORS
#For now, just warning: #For now, just warning:
exit 0 #exit 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册