提交 55d908c9 编写于 作者: T tangwei12

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_api_about_cpkt

......@@ -58,6 +58,8 @@ PaddlePaddle uses this [Git branching model](http://nvie.com/posts/a-successful-
create mode 100644 233
```
NOTE: The `yapf` installed by `pip install pre-commit` and `conda install -c conda-forge pre-commit` is slightly different. Paddle developers use `pip install pre-commit`.
1. Build and test
Users can build PaddlePaddle natively on Linux and Mac OS X. But to unify the building environment and to make it easy for debugging, the recommended way is [using Docker](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/build_en.md).
......
......@@ -29,7 +29,7 @@ RUN apt-get update && \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format swig doxygen cmake \
automake locales clang-format swig cmake \
liblapack-dev liblapacke-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools libtool ccache && \
......
......@@ -98,6 +98,8 @@ def parse_args():
'--use_fake_data',
action='store_true',
help='If set ommit the actual read data operators.')
parser.add_argument(
'--profile', action='store_true', help='If set, profile a few steps.')
parser.add_argument(
'--update_method',
type=str,
......@@ -108,8 +110,8 @@ def parse_args():
return args
def append_nccl2_prepare():
if os.getenv("PADDLE_TRAINER_ID", None) != None:
def append_nccl2_prepare(trainer_id):
if trainer_id >= 0:
# append gen_nccl_id at the end of startup program
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
port = os.getenv("PADDLE_PSERVER_PORT")
......@@ -136,12 +138,12 @@ def append_nccl2_prepare():
})
return nccl_id_var, num_trainers, trainer_id
else:
raise Exception(
"must set PADDLE_TRAINER_ID env variables for dist train.")
raise Exception("must set positive PADDLE_TRAINER_ID env variables for "
"nccl-based dist train.")
def dist_transpile():
if "PADDLE_TRAINING_ROLE" not in os.environ:
def dist_transpile(trainer_id):
if trainer_id < 0:
return None, None
# the port of all pservers, needed by both trainer and pserver
......@@ -158,9 +160,6 @@ def dist_transpile():
trainers = int(os.getenv("PADDLE_TRAINERS"))
# the IP of the local machine, needed by pserver only
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer
# only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
......@@ -295,6 +294,11 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
iters = 0
start_time = time.time()
for batch_id, data in enumerate(train_reader()):
if args.profile and pass_id == 0 and batch_id == 5:
profiler.start_profiler("All")
elif args.profile and pass_id == 0 and batch_id == 10:
profiler.stop_profiler("total", "/tmp/profile_%d" % trainer_id)
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
......@@ -334,7 +338,11 @@ def print_arguments(args):
def main():
args = parse_args()
print_arguments(args)
nccl_id_var, num_trainers, trainer_id = None, 1, 0
# the unique trainer id, starting from 0, needed by trainer
# only
nccl_id_var, num_trainers, trainer_id = (
None, 1, int(os.getenv("PADDLE_TRAINER_ID", "-1")))
if args.use_cprof:
pr = cProfile.Profile()
......@@ -348,7 +356,7 @@ def main():
fluid.memory_optimize(fluid.default_main_program())
if args.update_method == "pserver":
train_prog, startup_prog = dist_transpile()
train_prog, startup_prog = dist_transpile(trainer_id)
if not train_prog:
raise Exception(
"Must configure correct environments to run dist train.")
......@@ -364,7 +372,7 @@ def main():
train_args.append(fluid.default_startup_program())
if args.update_method == "nccl2":
nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare()
nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare(trainer_id)
if args.gpus == 1:
# NOTE: parallel executor use profiler interanlly
if args.use_nvprof and args.device == 'GPU':
......
......@@ -49,7 +49,7 @@ def parse_args():
parser.add_argument(
'--fluid', default=1, type=int, help='whether is fluid job')
parser.add_argument(
'--rdma', action='store_ture', help='whether mount rdma libs')
'--rdma', action='store_true', help='whether mount rdma libs')
parser.add_argument(
'--disttype',
default="pserver",
......
......@@ -86,7 +86,7 @@
<br>
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_compiler.png" width=100%>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid-compiler.png" width=100%>
</p>
---
......@@ -123,12 +123,12 @@
<font size=5>
- 在科学计算领域,计算图是一种描述计算的经典方式。下图展示了从前向计算图(蓝色)开始,通过添加反向(红色)和优化算法相关(绿色)操作,构建出整个计算图的过程:
-
-
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/graph_construction_example_all.png" width=60%>
</p>
- Fluid ==使用`Program`而不是计算图==来描述模型和优化过程。`Program``Block``Operator``Variable`构成,相关概念会在后文详细展开。
- 编译时 Fluid 接受前向计算(这里可以先简单的理解为是一段有序的计算流)`Program`,为这段前向计算按照:前向 -> 反向 -> 梯度 clip -> 正则 -> 优化 的顺序,添加相关 `Operator``Variable``Program`到完整的计算。
......@@ -328,7 +328,7 @@
</font>
---
---
### 编译时概念 :==**[Transpiler](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md)**==
<font size=5>
......@@ -402,7 +402,7 @@
- `Scope`
- 计算相关
- `Block`
- `Block`
- `Kernel``OpWithKernel``OpWithoutKernel`
<table>
......@@ -439,7 +439,7 @@
</tbody>
</table>
- 执行相关 :`Executor`
- 执行相关 :`Executor`
</font>
......@@ -798,7 +798,7 @@ class GPUAllocator : public SystemAllocator {
- step 1:添加Place类型,<span style="background-color:#DAB1D5;">由用户实现添加到框架</span>
- 可以将Place类型理解为一个整数加上一个枚举型,包括:设备号 + 设备类型
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/place.png" width=40%>
</p>
......@@ -824,7 +824,7 @@ class GPUAllocator : public SystemAllocator {
1. DataType 执行数据类型 FP32/FP64/INT32/INT64
1. Memory layout: 运行时 Tensor 在内存中的排布格式 NCHW、 NHWC
1. 使用的库
来区分Kernel,为同一个operator注册多个 Kernel。
```cpp
......@@ -876,7 +876,7 @@ step 3: 运行时的 KernelType 推断和Kernel切换,<span style="background-
namespace framework {
using LoDTensorArray = std::vector<LoDTensor>;
}
}
}
```
- 每一次循环,从原始输入中“切出”一个片段
- LoDTensorArray 在Python端暴露,是Fluid支持的基础数据结构之一,用户可以直接创建并使用
......@@ -910,7 +910,7 @@ void Run(const framework::Scope &scope,
false /*create_local_scope*/);
}
}
```
</font>
......@@ -951,7 +951,7 @@ void Run(const framework::Scope &scope,
---
#### dynamicRNN 中的 Memory
#### dynamicRNN 中的 Memory
<font size=5>
......@@ -961,7 +961,7 @@ void Run(const framework::Scope &scope,
- `memory` 在 operator A 前向计算之后,进行前向计算
-`memory` 的前向计算会 "指向" A 的输出 LoDTensor
- `memory` 的输出可以是另一个 operator 的输入,于是形成了“循环”连接
</font>
---
......@@ -1107,7 +1107,7 @@ void Run(const framework::Scope &scope,
<td>
<p align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_module_1.png" width=60%>
</p>
</p>
</td>
<td>
<p align="center">
......@@ -1127,13 +1127,13 @@ void Run(const framework::Scope &scope,
<font size=5>
- 设计概览
- 重构概览 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/refactorization.md)
- fluid [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/fluid.md)
- 重构概览 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/refactorization.md)
- fluid [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/fluid.md)
- fluid_compiler [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/motivation/fluid_compiler.md)
- 核心概念
- variable 描述 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/var_desc.md)
- Tensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.md)
- LoDTensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
- LoDTensor [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
- TensorArray [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md)
- Program [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md)
- Block [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md)
......@@ -1152,7 +1152,7 @@ void Run(const framework::Scope &scope,
- 支持新设硬件设备库 [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md)
- 添加新的Operator [->](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_cn.md)
- 添加新的Kernel [->](
https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_en.md)
https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_en.md)
</font>
......@@ -1167,10 +1167,10 @@ https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_kernel_
<font size=5>
Docker编译PaddlePaddle源码: [->](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html)
PaddlePaddle 在 Dockerhub 地址:[->](
https://hub.docker.com/r/paddlepaddle/paddle/tags/)
1. 获取PaddlePaddle的Docker镜像
```bash
docker pull paddlepaddle/paddle:latest-dev
......@@ -1183,7 +1183,7 @@ PaddlePaddle 在 Dockerhub 地址:[->](
```
1. 进入docker container后,从源码编译,请参考文档 [->]( http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/build_from_source_cn.html)
</font>
---
......@@ -1196,7 +1196,7 @@ PaddlePaddle 在 Dockerhub 地址:[->](
1. 开发推荐使用tag为`latest-dev`的镜像,其中打包了所有编译依赖。`latest``lastest-gpu`是production镜像,主要用于运行PaddlePaddle程序。
2. 在Docker中运行GPU程序,推荐使用nvidia-docker,[否则需要将CUDA库和设备挂载到Docker容器内](http://www.paddlepaddle.org/docs/develop/documentation/fluid/zh/build_and_install/docker_install_cn.html)
<font size=4>
```bash
nvidia-docker run -it -v $PWD/Paddle:/paddle paddlepaddle/paddle:latest-dev /bin/bash
```
......@@ -1353,9 +1353,9 @@ Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实
}
};
```
</font>
---
###### 实现带Kernel的Operator <span style="background-color:#c4e1e1;">step2</span>: 定义Operator类
......@@ -1420,11 +1420,11 @@ class ClipOp : public framework::OperatorWithKernel {
2. override InferShape函数(参考 [clip_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.cc#L24)
1. 什么是`functor` ?
- 类或结构体仅重载了`()`,一般是可被多个kernel复用的计算函数。
<font size=4>
```cpp
template <typename T>
class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
......@@ -1438,9 +1438,9 @@ class ClipOp : public framework::OperatorWithKernel {
};
```
</font>
- 在 clip_op 内也会看到将一段计算函数抽象为functor的使用法: [->](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/clip_op.h#L27)。
</font>
---
......@@ -1504,7 +1504,7 @@ class ClipKernel : public framework::OpKernel<T> {
- 需要注意,<span style="background-color:#e1c4c4;">Fluid中,不区分Cost Op和中间层Op,所有Op都必须正确处理接收到的梯度</span>
2. 反向Op的输出
- 对可学习参数的求导结果
- 对所有输入的求导结果
- 对所有输入的求导结果
</font>
......@@ -1520,7 +1520,7 @@ class ClipKernel : public framework::OpKernel<T> {
1.`.cc`文件中注册前向、反向Op类,注册CPU Kernel。
<font size=4>
```cpp
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
......@@ -1530,13 +1530,13 @@ class ClipKernel : public framework::OpKernel<T> {
REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>);
```
- 在上面的代码片段中:
1. `REGISTER_OP` : 注册`ops::ClipOp`类,类型名为`clip`,该类的`ProtoMaker`为`ops::ClipOpMaker`,注册`ops::ClipOpGrad`,类型名为`clip_grad`
1. `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op,例如:优化算法相关的Op
1. `REGISTER_OP_CPU_KERNEL` :注册`ops::ClipKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::ClipGradKernel`类
</font>
1. 按照同样方法,在`.cu`文件中注册GPU Kernel
- <span style="background-color:#e1c4c4;">如果CUDA Kernel的实现基于Eigen,需在 `.cu`的开始加上宏定义 `#define EIGEN_USE_GPU` </span>
......@@ -1593,7 +1593,7 @@ class ClipKernel : public framework::OpKernel<T> {
```bash
make test ARGS="-R test_mul_op -V"
```
或者:
```
......@@ -1613,7 +1613,7 @@ class ClipKernel : public framework::OpKernel<T> {
- 如果多个Op依赖一些共用的函数,可以创建非`*_op.*`格式的文件来存放,如`gather.h`文件。
</font>
---
### ==10.== 使用相关问题
......@@ -1735,7 +1735,7 @@ class ClipKernel : public framework::OpKernel<T> {
y_data = np.random.randint(0, 8, [1]).astype("int32")
y_tensor = core.Tensor()
y_tensor.set(y_data, place)
x_data = np.random.uniform(0.1, 1, [11, 8]).astype("float32")
x_tensor = core.Tensor()
x_tensor.set(x_data, place)
......
......@@ -17,3 +17,4 @@
:maxdepth: 1
concepts/use_concepts_cn.rst
developer's_guide_to_paddle_fluid.md
......@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle,
:maxdepth: 1
concepts/index_en.rst
developer's_guide_to_paddle_fluid.md
......@@ -11,7 +11,7 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
pip install paddlepaddle
如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行:
如果需要安装支持GPU的版本(cuda8.0_cudnn5_avx_openblas),需要执行:
.. code-block:: bash
......@@ -28,18 +28,18 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
import paddle.dataset.uci_housing as uci_housing
import paddle.fluid as fluid
with fluid.scope_guard(fluid.core.Scope()):
# initialize executor with cpu
exe = fluid.Executor(place=fluid.CPUPlace())
# load inference model
# load inference model
[inference_program, feed_target_names,fetch_targets] = \
fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
# run inference
result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()},
result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()},
fetch_list=fetch_targets)
# print predicted price is $12,273.97
# print predicted price is $12,273.97
print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
执行 :code:`python housing.py` 瞧! 它应该打印出预测住房数据的清单。
......@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas:
pip install paddlepaddle
If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run:
If you need to install GPU version (cuda8.0_cudnn5_avx_openblas), run:
.. code-block:: bash
......@@ -31,18 +31,18 @@ code:
import paddle.dataset.uci_housing as uci_housing
import paddle.fluid as fluid
with fluid.scope_guard(fluid.core.Scope()):
# initialize executor with cpu
exe = fluid.Executor(place=fluid.CPUPlace())
# load inference model
# load inference model
[inference_program, feed_target_names,fetch_targets] = \
fluid.io.load_inference_model(uci_housing.fluid_model(), exe)
# run inference
result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()},
result = exe.run(inference_program,
feed={feed_target_names[0]: uci_housing.predict_reader()},
fetch_list=fetch_targets)
# print predicted price is $12,273.97
# print predicted price is $12,273.97
print 'Predicted price: ${:,.2f}'.format(result[0][0][0] * 1000)
Run :code:`python housing.py` and voila! It should print out a list of predictions
......
......@@ -51,6 +51,8 @@ Paddle 开发人员使用 [pre-commit](http://pre-commit.com/) 工具来管理 G
Paddle 使用 `clang-format` 来调整 C/C++ 源代码格式,请确保 `clang-format` 版本在 3.8 以上。
注:通过`pip install pre-commit``conda install -c conda-forge pre-commit`安装的`yapf`稍有不同的,Paddle 开发人员使用的是`pip install pre-commit`
## 开始开发
在本例中,我删除了 README.md 中的一行,并创建了一个新文件。
......
......@@ -13,7 +13,11 @@
# limitations under the License.
#
function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE)
function(inference_api_test TARGET_NAME TEST_SRC)
set(options "")
set(oneValueArgs "")
set(multiValueArgs ARGS)
......@@ -32,8 +36,10 @@ function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
string(REGEX REPLACE "^_$" "" arg "${arg}")
cc_test(${TARGET_NAME}
SRCS ${TEST_SRC}
DEPS paddle_fluid_api paddle_inference_api paddle_inference_api_impl
DEPS paddle_fluid_api paddle_inference_api
ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
# TODO(panyx0178): Figure out how to add word2vec and image_classification
# as deps.
# set_tests_properties(${TARGET_NAME}
# PROPERTIES DEPENDS ${DEP_TEST})
endforeach()
......@@ -41,17 +47,12 @@ endfunction(inference_api_test)
cc_library(paddle_inference_api
SRCS paddle_inference_api.cc
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
cc_library(paddle_inference_api_impl
SRCS paddle_inference_api_impl.cc
DEPS paddle_inference_api paddle_fluid_api)
cc_test(test_paddle_inference_api
SRCS test_paddle_inference_api.cc
DEPS paddle_inference_api)
inference_api_test(test_paddle_inference_api_impl
test_paddle_inference_api_impl.cc
test_word2vec)
test_paddle_inference_api_impl.cc)
......@@ -45,10 +45,10 @@ struct PaddleTensor {
};
/*
* A simple Inference API for Paddle. Currently this API might just be used by
* non-sequence scenerios.
* TODO(Superjomn) Prepare another API for NLP-related usages.
*/
* A simple Inference API for Paddle. Currently this API can be used by
* non-sequence scenerios.
* TODO(Superjomn) Support another API for NLP-related usages.
*/
class PaddlePredictor {
public:
struct Config;
......@@ -66,34 +66,38 @@ class PaddlePredictor {
// be thread-safe.
virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
virtual bool InitShared() { return false; }
// Destroy the Predictor.
virtual ~PaddlePredictor() {}
friend std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
const PaddlePredictor::Config& config);
enum class EngineKind {
kNative = -1, // Use the native Fluid facility.
// TODO(Superjomn) support latter.
// kAnakin, // Use Anakin for inference.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
// The common configs for all the predictors.
struct Config {
enum class EngineKind;
std::string model_dir; // path to the model directory.
bool enable_engine{false}; // Enable to execute (part of) the model on
// third-party engines.
EngineKind engine_kind{Config::EngineKind::kNone};
enum class EngineKind {
kNone = -1, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
kTensorRT, // Use TensorRT for inference.
kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
};
};
struct NativeConfig : public PaddlePredictor::Config {
bool use_gpu{false};
int device;
float fraction_of_gpu_memory;
std::string prog_file;
std::string param_file;
bool share_variables;
};
// A factory to help create difference predictor.
template <typename ConfigT>
template <
typename ConfigT,
PaddlePredictor::EngineKind engine = PaddlePredictor::EngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
} // namespace paddle
......@@ -54,7 +54,7 @@ std::string num2str(T a) {
}
} // namespace
bool PaddlePredictorImpl::Init() {
bool NativePaddlePredictor::Init() {
VLOG(3) << "Predictor::init()";
// TODO(panyx0718): Should CPU vs GPU device be decided by id?
......@@ -96,14 +96,14 @@ bool PaddlePredictorImpl::Init() {
return true;
}
bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) {
bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) {
VLOG(3) << "Predictor::predict";
Timer timer;
timer.tic();
// set feed variable
std::map<std::string, const paddle::framework::LoDTensor *> feed_targets;
std::vector<paddle::framework::LoDTensor> feeds;
std::map<std::string, const framework::LoDTensor *> feed_targets;
std::vector<framework::LoDTensor> feeds;
if (!SetFeed(inputs, &feeds)) {
LOG(ERROR) << "fail to set feed";
return false;
......@@ -112,8 +112,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
feed_targets[feed_target_names_[i]] = &feeds[i];
}
// get fetch variable
std::map<std::string, paddle::framework::LoDTensor *> fetch_targets;
std::vector<paddle::framework::LoDTensor> fetchs;
std::map<std::string, framework::LoDTensor *> fetch_targets;
std::vector<framework::LoDTensor> fetchs;
fetchs.resize(fetch_target_names_.size());
for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
fetch_targets[fetch_target_names_[i]] = &fetchs[i];
......@@ -133,76 +133,33 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
return true;
}
std::unique_ptr<PaddlePredictor> PaddlePredictorImpl::Clone() {
std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
VLOG(3) << "Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new PaddlePredictorImpl(config_));
if (!cls->InitShared()) {
LOG(ERROR) << "fail to call InitShared";
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
LOG(ERROR) << "fail to call Init";
return nullptr;
}
// fix manylinux compile error.
return std::move(cls);
}
// TODO(panyx0718): Consider merge with Init()?
bool PaddlePredictorImpl::InitShared() {
VLOG(3) << "Predictor::init_shared";
// 1. Define place, executor, scope
if (this->config_.device >= 0) {
place_ = paddle::platform::CUDAPlace();
} else {
place_ = paddle::platform::CPUPlace();
}
this->executor_.reset(new paddle::framework::Executor(this->place_));
this->scope_.reset(new paddle::framework::Scope());
// Initialize the inference program
if (!this->config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
this->inference_program_ = paddle::inference::Load(
this->executor_.get(), this->scope_.get(), this->config_.model_dir);
} else if (!this->config_.prog_file.empty() &&
!this->config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
this->inference_program_ =
paddle::inference::Load(this->executor_.get(),
this->scope_.get(),
this->config_.prog_file,
this->config_.param_file);
}
this->ctx_ = this->executor_->Prepare(*this->inference_program_, 0);
// 3. create variables
// TODO(panyx0718): why test share_variables.
if (config_.share_variables) {
this->executor_->CreateVariables(
*this->inference_program_, this->scope_.get(), 0);
}
// 4. Get the feed_target_names and fetch_target_names
this->feed_target_names_ = this->inference_program_->GetFeedTargetNames();
this->fetch_target_names_ = this->inference_program_->GetFetchTargetNames();
return true;
}
bool PaddlePredictorImpl::SetFeed(
const std::vector<PaddleTensor> &inputs,
std::vector<paddle::framework::LoDTensor> *feeds) {
bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
std::vector<framework::LoDTensor> *feeds) {
VLOG(3) << "Predictor::set_feed";
if (inputs.size() != feed_target_names_.size()) {
LOG(ERROR) << "wrong feed input size.";
return false;
}
for (size_t i = 0; i < feed_target_names_.size(); ++i) {
paddle::framework::LoDTensor input;
paddle::framework::DDim ddim =
paddle::framework::make_ddim(inputs[i].shape);
framework::LoDTensor input;
framework::DDim ddim = framework::make_ddim(inputs[i].shape);
void *input_ptr;
if (inputs[i].dtype == PaddleDType::INT64) {
input_ptr =
input.mutable_data<int64_t>(ddim, paddle::platform::CPUPlace());
input_ptr = input.mutable_data<int64_t>(ddim, platform::CPUPlace());
} else if (inputs[i].dtype == PaddleDType::FLOAT32) {
input_ptr = input.mutable_data<float>(ddim, paddle::platform::CPUPlace());
input_ptr = input.mutable_data<float>(ddim, platform::CPUPlace());
} else {
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
return false;
......@@ -213,13 +170,12 @@ bool PaddlePredictorImpl::SetFeed(
inputs[i].data.data,
inputs[i].data.length);
feeds->push_back(input);
LOG(ERROR) << "Actual feed type " << feeds->back().type().name();
}
return true;
}
bool PaddlePredictorImpl::GetFetch(
const std::vector<paddle::framework::LoDTensor> &fetchs,
bool NativePaddlePredictor::GetFetch(
const std::vector<framework::LoDTensor> &fetchs,
std::vector<PaddleTensor> *outputs) {
VLOG(3) << "Predictor::get_fetch";
outputs->resize(fetchs.size());
......@@ -284,27 +240,30 @@ bool PaddlePredictorImpl::GetFetch(
return true;
}
std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
const VisConfig &config) {
VLOG(3) << "create PaddlePredictorImpl";
// 1. GPU memeroy
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
num2str<float>(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<NativeConfig, PaddlePredictor::EngineKind::kNative>(
const NativeConfig &config) {
VLOG(3) << "create NativePaddlePredictor";
if (config.use_gpu) {
// 1. GPU memeroy
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
num2str<float>(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictorImpl> predictor(
new PaddlePredictorImpl(config));
if (!predictor->Init()) {
std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init()) {
return nullptr;
}
return predictor;
return std::move(predictor);
}
} // namespace paddle
......@@ -29,20 +29,10 @@
namespace paddle {
struct VisConfig : public PaddlePredictor::Config {
int device;
float fraction_of_gpu_memory;
std::string prog_file;
std::string param_file;
bool share_variables;
};
/*
* Do not use this, just a demo indicating how to customize a Predictor.
*/
class PaddlePredictorImpl : public PaddlePredictor {
class NativePaddlePredictor : public PaddlePredictor {
public:
explicit PaddlePredictorImpl(const VisConfig &config) : config_(config) {}
explicit NativePaddlePredictor(const NativeConfig &config)
: config_(config) {}
bool Init();
......@@ -51,26 +41,22 @@ class PaddlePredictorImpl : public PaddlePredictor {
std::unique_ptr<PaddlePredictor> Clone() override;
~PaddlePredictorImpl() override{};
~NativePaddlePredictor() override{};
private:
bool InitShared() override;
bool SetFeed(const std::vector<PaddleTensor> &input_datas,
std::vector<paddle::framework::LoDTensor> *feeds);
bool GetFetch(const std::vector<paddle::framework::LoDTensor> &fetchs,
std::vector<framework::LoDTensor> *feeds);
bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
std::vector<PaddleTensor> *output_data);
VisConfig config_;
paddle::platform::Place place_;
std::unique_ptr<paddle::framework::Executor> executor_;
std::unique_ptr<paddle::framework::Scope> scope_;
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx_;
std::unique_ptr<paddle::framework::ProgramDesc> inference_program_;
NativeConfig config_;
platform::Place place_;
std::unique_ptr<framework::Executor> executor_;
std::unique_ptr<framework::Scope> scope_;
std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
std::unique_ptr<framework::ProgramDesc> inference_program_;
std::vector<std::string> feed_target_names_;
std::vector<std::string> fetch_target_names_;
};
std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
const VisConfig &config);
} // namespace paddle
......@@ -40,16 +40,20 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
return pt;
}
TEST(paddle_inference_api_impl, word2vec) {
VisConfig config;
NativeConfig GetConfig() {
NativeConfig config;
config.model_dir = FLAGS_dirname + "word2vec.inference.model";
LOG(INFO) << "dirname " << config.model_dir;
config.fraction_of_gpu_memory = 0.15;
config.use_gpu = true;
config.device = 0;
config.share_variables = true;
return config;
}
std::unique_ptr<PaddlePredictorImpl> predictor =
CreatePaddlePredictorImpl(config);
TEST(paddle_inference_api_impl, word2vec) {
NativeConfig config = GetConfig();
auto predictor = CreatePaddlePredictor<NativeConfig>(config);
framework::LoDTensor first_word, second_word, third_word, fourth_word;
framework::LoD lod{{0, 1}};
......@@ -60,24 +64,90 @@ TEST(paddle_inference_api_impl, word2vec) {
SetupLoDTensor(&third_word, lod, static_cast<int64_t>(0), dict_size - 1);
SetupLoDTensor(&fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);
std::vector<PaddleTensor> cpu_feeds;
cpu_feeds.push_back(LodTensorToPaddleTensor(&first_word));
cpu_feeds.push_back(LodTensorToPaddleTensor(&second_word));
cpu_feeds.push_back(LodTensorToPaddleTensor(&third_word));
cpu_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
std::vector<PaddleTensor> paddle_tensor_feeds;
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&first_word));
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&second_word));
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&third_word));
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL);
size_t len = outputs[0].data.length;
float* data = static_cast<float*>(outputs[0].data.data);
for (int j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0);
}
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&first_word);
cpu_feeds.push_back(&second_word);
cpu_feeds.push_back(&third_word);
cpu_feeds.push_back(&fourth_word);
framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
TestInference<platform::CPUPlace>(config.model_dir, cpu_feeds, cpu_fetchs1);
float* lod_data = output1.data<float>();
for (size_t i = 0; i < output1.numel(); ++i) {
EXPECT_LT(lod_data[i] - data[i], 1e-3);
EXPECT_GT(lod_data[i] - data[i], -1e-3);
}
free(outputs[0].data.data);
}
TEST(paddle_inference_api_impl, image_classification) {
int batch_size = 2;
bool use_mkldnn = false;
bool repeat = false;
NativeConfig config = GetConfig();
config.model_dir =
FLAGS_dirname + "image_classification_resnet.inference.model";
const bool is_combined = false;
std::vector<std::vector<int64_t>> feed_target_shapes =
GetFeedTargetShapes(config.model_dir, is_combined);
framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
feed_target_shapes[0][0] = batch_size;
framework::DDim input_dims = framework::make_ddim(feed_target_shapes[0]);
SetupTensor<float>(
&input, input_dims, static_cast<float>(0), static_cast<float>(1));
std::vector<framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
framework::LoDTensor output1;
std::vector<framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
TestInference<platform::CPUPlace, false, true>(config.model_dir,
cpu_feeds,
cpu_fetchs1,
repeat,
is_combined,
use_mkldnn);
auto predictor = CreatePaddlePredictor(config);
std::vector<PaddleTensor> paddle_tensor_feeds;
paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input));
std::vector<PaddleTensor> outputs;
ASSERT_TRUE(predictor->Run(cpu_feeds, &outputs));
ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
ASSERT_EQ(outputs.size(), 1UL);
for (size_t i = 0; i < outputs.size(); ++i) {
size_t len = outputs[i].data.length;
float* data = static_cast<float*>(outputs[i].data.data);
for (size_t j = 0; j < len / sizeof(float); ++j) {
ASSERT_LT(data[j], 1.0);
ASSERT_GT(data[j], -1.0);
}
free(outputs[i].data.data);
size_t len = outputs[0].data.length;
float* data = static_cast<float*>(outputs[0].data.data);
float* lod_data = output1.data<float>();
for (size_t j = 0; j < len / sizeof(float); ++j) {
EXPECT_NEAR(lod_data[j], data[j], 1e-3);
}
free(data);
}
} // namespace paddle
......@@ -469,6 +469,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
protected:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
......
......@@ -25,8 +25,10 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
if (out->empty()) {
return;
}
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &actual = (*out)[i].dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
......
......@@ -18,8 +18,8 @@ namespace paddle {
namespace framework {
struct ReAllocateVisitor {
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
: tensor_(tensor), dims_(dims) {}
ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
: dims_(dims), tensor_(tensor) {}
template <typename T>
void operator()() const {
......@@ -34,8 +34,8 @@ struct ReAllocateVisitor {
tensor_->ShareDataWith(cpu_tensor);
}
framework::Tensor* tensor_;
framework::DDim dims_;
framework::Tensor* tensor_;
};
struct TensorCopyVisitor {
......@@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
}
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
"The first dim of value should be 1.");
std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
auto index = Index(key);
bool is_new_key = false;
if (index == -1) {
......@@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
auto dims = value_->dims();
dims[0] = (dims[0] + 1) << 1;
framework::VisitDataType(framework::ToDataType(value.type()),
ReAllocateVisitor(value_.get(), dims));
ReAllocateVisitor(dims, value_.get()));
}
}
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#include <vector>
......@@ -46,11 +48,13 @@ class SelectedRows {
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
}
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
}
platform::Place place() const { return value_->place(); }
......@@ -125,6 +129,7 @@ class SelectedRows {
Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;
std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
};
/*
......
......@@ -39,7 +39,7 @@ template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code(),
holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
......@@ -53,7 +53,7 @@ template <typename T>
inline T* Tensor::data() {
check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code(),
holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
......
......@@ -5,14 +5,19 @@ cc_library(paddle_fluid_api
SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
# Create static library
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
cc_library(paddle_fluid DEPS ${fluid_modules})
if(WITH_CONTRIB)
set(fluid_modules "${fluid_modules}" paddle_inference_api)
endif()
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api)
# Create shared library
cc_library(paddle_fluid_shared SHARED
SRCS io.cc
DEPS ${fluid_modules})
DEPS ${fluid_modules} paddle_fluid_api)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
if(NOT APPLE)
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
......
......@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
......
......@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG(INFO) << graph.nodes.size();
}
} // analysis
} // inference
} // paddle
}; // namespace analysis
}; // namespace inference
}; // namespace paddle
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
......
......@@ -19,6 +19,8 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
......
......@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG(INFO) << '\n' << graph.DotString();
}
} // analysis
} // inference
} // paddle
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -50,7 +50,7 @@ struct DataTypeNamer {
return dic_.at(x);
}
const std::string &repr(size_t &hash) const {
const std::string &repr(size_t &hash) const { // NOLINT
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
return dic_.at(hash);
}
......@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE(float);
}
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_;
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
std::string>
dic_;
};
#undef SET_TYPE
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h"
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -58,7 +61,7 @@ class TRTConvertValidation {
public:
TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
// create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork();
......
......@@ -131,6 +131,20 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name).buffer;
}
void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
size_t max_size) {
// determine data size
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
cudaMemcpyDeviceToDevice, *stream_),
0);
}
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
size_t max_size) {
// determine data size
......@@ -152,7 +166,7 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
return buffers_[slot_offset];
}
void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
size_t size) {
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer);
......@@ -162,6 +176,16 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
cudaMemcpyHostToDevice, *stream_));
}
void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
size_t size) {
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
cudaMemcpyDeviceToDevice, *stream_));
}
void TensorRTEngine::SetITensor(const std::string& name,
nvinfer1::ITensor* tensor) {
PADDLE_ENFORCE(tensor != nullptr);
......
......@@ -92,13 +92,15 @@ class TensorRTEngine : public EngineBase {
cudaStream_t* stream() { return stream_; }
// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
void SetInputFromCPU(const std::string& name, const void* data, size_t size);
// TODO(Superjomn) is this method necessary given that buffer(xxx) can be
// accessed directly. Fill an input from GPU memory with name and size.
void SetInputFromGPU(const std::string& name, void* data, size_t size);
void SetInputFromGPU(const std::string& name, const void* data, size_t size);
// Get an output called name, the output of tensorrt is in GPU, so this method
// will just return the output's GPU memory address.
// Return the output's GPU memory address without copy.
void* GetOutputInGPU(const std::string& name);
// Copy data into dst inside the GPU device.
void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
// LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
// to CPU.
void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
......
......@@ -168,6 +168,8 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......@@ -223,6 +225,11 @@ op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine)
else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif()
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
op_library(print_op DEPS lod_tensor)
......
......@@ -89,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>);
......@@ -21,5 +21,5 @@ using CastOpKernel =
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>,
CastOpKernel<bool>,
CastOpKernel<bool>, CastOpKernel<uint8_t>,
CastOpKernel<paddle::platform::float16>);
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
......
......@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
}
bool RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()"
<< " req_count_:" << req_count_;
if (req_count_ <= 0) {
return true;
}
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <map>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
......@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
......@@ -37,106 +41,48 @@ namespace paddle {
namespace operators {
namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
class AsyncGRPCServer final {
class AsyncGRPCServer final : public RPCServer {
public:
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
explicit AsyncGRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
void Push(const std::string &msg_name) {
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
}
virtual ~AsyncGRPCServer() {}
void WaitServerReady() override;
void StartServer() override;
void ShutDown();
private:
void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne);
protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int req_id);
void TryToRegisterNewOne(const std::string& rpc_name, int req_id);
void ShutdownQueue();
void ShutDownImpl() override;
private:
static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
static const int kRequestBufSize = 100;
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_;
std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program
std::mutex barrier_mutex_;
mutable int barrier_cond_step_;
std::condition_variable barrier_condition_;
std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
std::map<std::string, std::vector<std::unique_ptr<std::thread>>> rpc_threads_;
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
};
}; // namespace detail
......
......@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;
USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
......@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
void StartServer() {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
rpc_service_->RunSyncUpdate();
// FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
}
TEST(PREFETCH, DISABLED_CPU) {
// start up a server instance backend
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient client;
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
auto client = detail::RPCClient::GetInstance();
client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client.Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
}
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
class RPCServer;
class RequestHandler {
public:
explicit RequestHandler(bool sync_mode)
: sync_mode_(sync_mode),
dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr),
rpc_server_(nullptr) {}
virtual ~RequestHandler() {}
// Set attributes.
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
grad_to_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var,
framework::Variable** outvar) = 0;
protected:
const bool sync_mode_;
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle {
namespace operators {
namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
rpc_server_->WaitCond(kRequestSend);
}
if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
sparse_vars_.push_back(invar);
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
if (sync_mode_) {
rpc_server_->WaitCond(kRequestGet);
}
*outvar = scope_->FindVar(varname);
return true;
}
// FETCH_BARRIER_MESSAGE
if (sync_mode_) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown ";
ShutDownImpl();
exit_flag_ = true;
barrier_cond_.notify_all();
rpc_cond_.notify_all();
}
void RPCServer::SavePort() const {
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
if (b >= client_num_) {
barrier_cond_.notify_all();
}
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
for (auto& t : barrier_counter_) {
t.second = 0;
}
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
rpc_call_map_[rpc_name] = handler;
rpc_thread_num_[rpc_name] = thread_num;
static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond;
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
<< ", cond:" << rpc_cond_map_[rpc_name];
}
void RPCServer::SetCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer SetCond " << rpc_name;
{
std::unique_lock<std::mutex> lock(mutex_);
cur_cond_ = rpc_cond_map_[rpc_name];
}
rpc_cond_.notify_all();
}
void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
cond = rpc_cond_map_[rpc_name];
}
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
class RPCServer {
public:
explicit RPCServer(const std::string& address, int client_num)
: cur_cond_(0),
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
virtual void WaitServerReady() = 0;
void ShutDown();
bool IsExit() { return exit_flag_.load(); }
int GetSelectedPort() const { return selected_port_; }
void SavePort() const;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void WaitBarrier(const std::string& rpc_name);
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
protected:
virtual void ShutDownImpl() = 0;
private:
std::mutex mutex_;
std::unordered_map<std::string, int> barrier_counter_;
std::condition_variable barrier_cond_;
std::unordered_map<std::string, int> rpc_cond_map_;
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;
const int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
friend class RequestHandler;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
......@@ -67,8 +67,8 @@ class VariableResponse {
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
// should call parse first.
framework::Variable* GetVar() {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
namespace paddle {
namespace operators {
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
FakeDequantizeMaxAbsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("num_bits",
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8.");
AddAttr<float>("scale",
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op.");
AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
ops::FakeDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
int num_bits = ctx.Attr<int>("num_bits");
T scale = static_cast<T>(ctx.Attr<float>("scale"));
int range = std::pow(2, num_bits) - 1;
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(dev) = (scale / range) * eigen_in;
}
};
} // namespace operators
} // namespace paddle
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
......@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
detail::AsyncGRPCServer rpc_service(endpoint, true);
detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1);
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(&rpc_service);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_service.SetScope(scope);
rpc_service.SetDevCtx(&dev_ctx);
rpc_service.SetProgram(&empty_program);
rpc_service.SetExecutor(&executor);
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
rpc_service.SetCond(0);
std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service));
rpc_service.SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service.Get();
rpc_service.WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown();
VLOG(3) << "rpc server stopped";
......
......@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
void RunServer(std::shared_ptr<detail::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
......@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
......@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown();
server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
......@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void ListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
rpc_service_->SavePort();
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin");
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
......@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
bool exit_flag = false;
rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag && !SignalHandler::IsProgramExit()) {
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
if (exit_flag) {
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
rpc_service_->SetCond(detail::kRequestSend);
rpc_service_->WaitBarrier(detail::kRequestSend);
if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(detail::kRequestGet);
break;
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent();
......@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear();
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
} // while(true)
}
static void AsyncUpdateThread(
const std::string &var_name, const bool &exit_flag,
const std::shared_ptr<detail::ReceivedQueue> &queue,
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = queue->Pop();
if (SignalHandler::IsProgramExit()) {
VLOG(3) << "update thread for " << var_name << " exit";
break;
}
auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope());
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
fs.wait();
}
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
......@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
......@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
bool exit_flag = false;
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
while (true) {
if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!";
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
grad_to_queue[recv_var_name]->Push(v);
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
sleep(1);
} // while(true)
}
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(std::move(
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx)));
h->SetRPCServer(rpc_server);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
......@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
......@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
exit(0);
}
} // namespace operators
......
......@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
......@@ -31,7 +32,7 @@ namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
void RunServer(std::shared_ptr<detail::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase {
public:
......@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void SavePort() const;
void WaitServerReady();
int GetSelectedPort() { return selected_port_; }
int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
void Stop() override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<detail::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
};
class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public:
static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler);
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
template <typename Format = mkldnn::memory::format>
mkldnn::memory::desc type(const std::vector<int>& dims, Format&& f) {
return platform::MKLDNNMemDesc(dims, mkldnn::memory::data_type::f32, f);
}
template <typename T>
class MulMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<Tensor>("X");
auto weight = ctx.Input<Tensor>("Y");
PADDLE_ENFORCE(input->dims().size() & (2 | 4),
"Input must be with 2 or 4 dimensions, i.e. NC or NCHW");
PADDLE_ENFORCE(weight->dims().size() & (2 | 4),
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
std::vector<int> w_tz = paddle::framework::vectorize2int(weight->dims());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto output = ctx.Output<Tensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string key = ctx.op().Output("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = weight->data<T>();
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd = platform::MKLDNNFwdPrimitiveDesc<mkldnn::inner_product_forward>(
mkldnn_engine, src_md, weights_md, dst_md);
dev_ctx.SetBlob(key_fc_pd, pd);
auto forward = mkldnn::inner_product_forward(*pd, src_memory,
weights_memory, dst_memory);
std::vector<mkldnn::primitive> pipeline = {forward};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
template <typename T>
class MulMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
const Tensor* w = ctx.Input<Tensor>("Y");
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
const std::string key = ctx.op().Input("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
const T* out_grad_data = out_grad->data<T>();
T* input_grad_data = nullptr;
T* w_grad_data = nullptr;
if (input_grad) {
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (w_grad) {
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> w_tz = paddle::framework::vectorize2int(w->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine},
platform::to_void_cast(out_grad_data));
auto weight_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd =
std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
dev_ctx.GetBlob(key_fc_pd));
PADDLE_ENFORCE(pd != nullptr, "Fail to find pd in device context");
if (w_grad) {
auto weights_grad_memory = mkldnn::memory(
{weights_md, mkldnn_engine}, platform::to_void_cast(w_grad_data));
auto bwd_weight_pd = platform::MKLDNNBwdPrimitiveDesc<
mkldnn::inner_product_backward_weights>(mkldnn_engine, *pd, src_md,
weights_md, dst_md);
auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
bwd_weight_pd, src_memory, dst_memory, weights_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
if (input_grad) {
auto src_grad_memory = mkldnn::memory(
{src_md, mkldnn_engine}, platform::to_void_cast(input_grad_data));
auto bwd_data_pd =
platform::MKLDNNBwdPrimitiveDesc<mkldnn::inner_product_backward_data>(
mkldnn_engine, *pd, src_md, weights_md, dst_md);
auto bwd_data_prim = mkldnn::inner_product_backward_data(
bwd_data_pd, dst_memory, weight_memory, src_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNGradOpKernel<float>);
......@@ -16,10 +16,6 @@ limitations under the License. */
#include <string>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -76,22 +72,6 @@ class MulOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
}
};
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -120,9 +100,6 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
)DOC")
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<int>(
"y_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two,
......@@ -177,22 +154,6 @@ class MulGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
}
};
} // namespace operators
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/random_crop_op.h"
namespace paddle {
namespace operators {
class RandomCropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "A batch of instances to random crop.");
AddInput("Seed", "The random seed.");
AddOutput("Out", "The cropped instance batch.");
AddOutput("SeedOut", "The random seed after random cropping.")
.AsDispensable();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddComment(R"DOC(
This operator takes a batch of instance, and do random cropping on each instance.
It means that cropping positions differs on each instance, which is determined
by an uniform random generator. All cropped instances have the same shape, which
is determined by the operator's attribute 'shape'.
)DOC");
}
};
class RandomCropOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
auto seed_dim = ctx->GetInputDim("Seed");
PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
auto out_dim = framework::vectorize2int(x_dim);
for (size_t i = 1; i <= shape.size(); ++i) {
size_t x_i = x_dim.size() - i;
size_t shape_i = shape.size() - i;
PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace f = paddle::framework;
REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
ops::RandomCropOpInferShape, f::EmptyGradOpMaker);
template <typename T>
using Kernel = ops::RandomCropKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
Kernel<uint8_t>, Kernel<int16_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/random_crop_op.h"
namespace ops = paddle::operators;
template <typename T>
using Kernel = ops::RandomCropKernel<paddle::platform::CUDADeviceContext, T>;
REGISTER_OP_CUDA_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
Kernel<uint8_t>, Kernel<int16_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef PADDLE_WITH_CUDA
#include <thrust/random.h>
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext>
struct Random;
template <>
struct Random<platform::CPUDeviceContext> {
using Engine = std::minstd_rand;
template <typename T>
using UniformIntDist = std::uniform_int_distribution<T>;
};
#ifdef PADDLE_WITH_CUDA
template <>
struct Random<platform::CUDADeviceContext> {
using Engine = thrust::minstd_rand;
template <typename T>
using UniformIntDist = thrust::uniform_int_distribution<T>;
};
#endif
template <typename T>
HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
const size_t* out_dims, int i, int rank,
size_t prod_x_remain,
size_t prod_out_remain,
const size_t* offsets) {
size_t x_dim_i = x_dims[i];
size_t out_dim_i = out_dims[i];
size_t x_stride = prod_x_remain / x_dim_i;
size_t out_stride = prod_out_remain / out_dim_i;
size_t offset_i = offsets[i];
if (i == rank - 1) {
PADDLE_ASSERT(x_stride == 1 && out_stride == 1);
x += offset_i;
for (size_t j = 0; j < out_dim_i; ++j) {
*out++ = *x++;
}
} else {
x += offset_i * x_stride;
for (size_t j = 0; j < out_dim_i; ++j) {
StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
out_stride, offsets);
x += x_stride;
out += out_stride;
}
}
}
template <typename DeviceContext, typename T>
struct RandomCropFunctor {
const T* x_;
T* out_;
size_t x_dims_[9];
size_t out_dims_[9];
int num_batchsize_dims_;
int rank_;
int64_t seed_;
size_t prod_batchsize_dims_;
size_t prod_x_ins_dims_;
size_t prod_out_ins_dims_;
RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims,
const framework::DDim& out_dims, int num_batchsize_dims,
int64_t seed)
: x_(x),
out_(out),
num_batchsize_dims_(num_batchsize_dims),
rank_(x_dims.size()),
seed_(seed) {
PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size());
PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_);
prod_batchsize_dims_ = 1;
prod_x_ins_dims_ = 1;
prod_out_ins_dims_ = 1;
for (size_t i = 0; i < static_cast<size_t>(rank_); ++i) {
size_t x_dim_i = x_dims[i];
size_t out_dim_i = out_dims[i];
x_dims_[i] = x_dim_i;
out_dims_[i] = out_dim_i;
if (i < static_cast<size_t>(num_batchsize_dims_)) {
PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i);
prod_batchsize_dims_ *= x_dim_i;
} else {
prod_x_ins_dims_ *= x_dim_i;
prod_out_ins_dims_ *= out_dim_i;
}
}
}
HOSTDEVICE void operator()(size_t ins_idx) {
typename Random<DeviceContext>::Engine engine(seed_);
engine.discard(ins_idx * (rank_ - num_batchsize_dims_));
size_t offsets[9];
for (int i = num_batchsize_dims_; i < rank_; ++i) {
typename Random<DeviceContext>::template UniformIntDist<size_t> dist(
0, x_dims_[i] - out_dims_[i]);
offsets[i - num_batchsize_dims_] = dist(engine);
}
const T* x = x_ + ins_idx * prod_x_ins_dims_;
T* out = out_ + ins_idx * prod_out_ins_dims_;
StridedMemcpy<T>(x, x_dims_ + num_batchsize_dims_, out,
out_dims_ + num_batchsize_dims_, 0,
rank_ - num_batchsize_dims_, prod_x_ins_dims_,
prod_out_ins_dims_, offsets);
}
};
template <typename DeviceContext, typename T>
class RandomCropKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
int64_t seed = 0;
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
}
auto shape = ctx.Attr<std::vector<int>>("shape");
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
int num_batchsize_dims = x.dims().size() - shape.size();
RandomCropFunctor<DeviceContext, T> functor(
x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
num_batchsize_dims, seed);
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(),
functor.prod_batchsize_dims_);
for_range(functor);
Random<platform::CPUDeviceContext>::Engine engine(seed);
engine.discard(functor.prod_batchsize_dims_ *
(functor.rank_ - functor.num_batchsize_dims_));
*ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
platform::CPUPlace()) = engine();
}
};
// TODO(fengjiayi): Backward of random crop op
} // namespace operators
} // namespace paddle
......@@ -23,13 +23,12 @@ namespace reader {
class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
const platform::Place& dev_place,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
program_(*sub_block.Program()),
sub_block_id_(sub_block.ID()),
exe_(framework::Executor(dev_place)),
exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}
......@@ -60,7 +59,7 @@ class CreateCustomReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
new CustomReader(underlying_reader.Get(), *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
......@@ -85,9 +84,10 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
CreateCustomReader Operator
A custom reader can be used for input data preprocessing.
A custom reader holds its own sub-block, which will be executed in its
'ReadNext()' function. Users can configurate their own preprocessing
pipelines by inserting operators into custom reader's sub-block.
A custom reader holds its own sub-block, which will be executed in CPU
in its 'ReadNext()' function. Users can configurate their own
preprocessing pipelines by inserting operators into custom reader's
sub-block.
)DOC");
}
};
......
......@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/tensorrt_engine_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
const framework::ExecutionContext &context) const {
// Get the ProgramDesc and pass to convert.
const auto &block = context.Attr<framework::proto::BlockDesc>("subgraph");
max_batch_ = context.Attr<int>("max_batch");
auto max_workspace = context.Attr<int>("max_workspace");
engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_, max_workspace, nullptr));
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block, engine_.get());
engine_->FreezeNetwork();
}
class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Xs", "A list of inputs.").AsDuplicable();
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph");
AddComment("TensorRT engine operator.");
}
};
class TensorRTEngineInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
REGISTER_OP_CPU_KERNEL(
tensorrt_engine,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, float>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, double>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>);
#endif // PADDLE_WITH_CUDA
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace operators {
class TensorRTEngineOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace());
return kt;
}
};
template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
if (!engine_) {
Prepare(context);
}
auto input_names = context.op().Inputs("Xs");
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
// Try to determine a batch_size
auto* tensor0 = context.Input<framework::LoDTensor>(input_names.front());
PADDLE_ENFORCE_NOT_NULL(tensor0);
int batch_size = tensor0->dims()[0];
PADDLE_ENFORCE_LE(batch_size, max_batch_);
// Convert input tensor from fluid to engine.
for (const auto& x : context.Inputs("Xs")) {
// convert input and copy to TRT engine's buffer
auto* v = context.scope().FindVar(x);
PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x);
auto& t = v->Get<framework::LoDTensor>();
if (platform::is_cpu_place(t.place())) {
engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
} else {
engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
}
}
// Execute the engine.
PADDLE_ENFORCE_GT(batch_size, 0);
engine_->Execute(batch_size);
// Convert output tensor from engine to fluid
for (const auto& y : context.Outputs("Ys")) {
// convert output and copy to fluid.
nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
auto dims = trt_t->getDimensions();
// Use the output ITensor's dims to reshape the Fluid Tensor.
std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
auto* fluid_v = context.scope().FindVar(y);
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(framework::make_ddim(ddim));
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
if (platform::is_cpu_place(fluid_t->place())) {
engine_->GetOutputInCPU(
y, fluid_t->mutable_data<float>(platform::CPUPlace()), size);
} else {
engine_->GetOutputInGPU(
y, fluid_t->mutable_data<float>(platform::CUDAPlace()), size);
}
}
}
protected:
// Build the engine.
void Prepare(const framework::ExecutionContext& context) const;
private:
mutable std::unique_ptr<inference::tensorrt::TensorRTEngine> engine_;
mutable int max_batch_{0};
};
} // namespace operators
} // namespace paddle
#endif // PADDLE_WITH_CUDA
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
......@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail;
namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer(std::atomic<bool>* initialized) {
void StartServer() {
f::Scope scope;
p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor);
g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
*initialized = true;
rpc_service->SetCond(0);
auto recv = rpc_service->Get();
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend);
std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown();
g_rpc_service->ShutDown();
server_thread.join();
}
TEST(SendNcclId, DISABLED_Normal) {
std::atomic<bool> initialized{false};
std::thread server_thread(StartServer, &initialized);
while (!initialized) {
}
// wait server to start
// sleep(2);
rpc_service->WaitServerReady();
TEST(SendNcclId, GrpcServer) {
g_req_handler.reset(new detail::RequestSendHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
f::Scope scope;
p::CPUPlace place;
......@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort();
int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
LOG(INFO) << "connect to server" << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait();
client.AsyncSendBatchBarrier(ep);
client.Wait();
server_thread.join();
auto* ptr = rpc_service.release();
delete ptr;
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
......@@ -55,6 +55,9 @@ class TopkKernel : public framework::OpKernel<T> {
// NOTE: eigen shape doesn't affect paddle tensor.
eg_input.reshape(flat2dims);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (size_t i = 0; i < row; i++) {
std::vector<std::pair<T, size_t>> vec;
for (size_t j = 0; j < col; j++) {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <vector>
......
......@@ -38,6 +38,7 @@ struct EventList;
static int64_t profiler_lister_id = 0;
static bool should_send_profile_state = false;
std::mutex profiler_mu;
// The profiler state, the initial value is ProfilerState::kDisabled
static ProfilerState g_state = ProfilerState::kDisabled;
......@@ -228,6 +229,8 @@ void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE(state != ProfilerState::kDisabled,
"Can't enbale profling, since the input state is ",
"ProfilerState::kDisabled");
std::lock_guard<std::mutex> l(profiler_mu);
if (state == g_state) {
return;
}
......@@ -295,7 +298,7 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
} else if (g_state == ProfilerState::kAll) {
place = "All";
} else {
PADDLE_THROW("Invalid profiler state");
PADDLE_THROW("Invalid profiler state", g_state);
}
std::cout << "Place: " << place << std::endl;
......@@ -443,6 +446,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path) {
std::lock_guard<std::mutex> l(profiler_mu);
if (g_state == ProfilerState::kDisabled) return;
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);
......@@ -466,7 +470,7 @@ void SetProfileListener() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<std::mt19937::result_type>::max());
1, std::numeric_limits<int>::max());
profiler_lister_id = dist6(rng);
}
int64_t ListenerId() { return profiler_lister_id; }
......
......@@ -117,6 +117,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCPUTensorSetFromArray<int64_t>)
.def("set", PyCPUTensorSetFromArray<bool>)
.def("set", PyCPUTensorSetFromArray<uint16_t>)
.def("set", PyCPUTensorSetFromArray<uint8_t>)
#ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>)
......@@ -124,12 +125,14 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCUDATensorSetFromArray<int64_t>)
.def("set", PyCUDATensorSetFromArray<bool>)
.def("set", PyCUDATensorSetFromArray<uint16_t>)
.def("set", PyCUDATensorSetFromArray<uint8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<float>)
.def("set", PyCUDAPinnedTensorSetFromArray<int>)
.def("set", PyCUDAPinnedTensorSetFromArray<double>)
.def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<bool>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
#endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("set_float_element", TensorSetElement<float>)
......@@ -492,6 +495,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
// -- python binds for parallel executor.
......
......@@ -66,18 +66,18 @@ class NeonDepthwiseConvFunction : public ConvFunctionBase {
float* inputPadding = inputData;
int padInputHeight = inputHeight + 2 * paddingH();
int padInputWidth = inputWidth + 2 * paddingW();
if (paddingH() > 0 || paddingW() > 0) {
int newSize = batchSize * inputChannels * padInputHeight * padInputWidth;
resizeBuffer<Device>(newSize);
inputPadding = reinterpret_cast<float*>(memory_->getBuf());
neon::Padding<float>::run(inputData,
inputPadding,
batchSize * inputChannels,
inputHeight,
inputWidth,
padInputHeight,
padInputWidth);
}
int newSize =
batchSize * (inputChannels + 1) * padInputHeight * padInputWidth;
resizeBuffer<Device>(newSize);
inputPadding = reinterpret_cast<float*>(memory_->getBuf());
neon::Padding<float>::run(inputData,
inputPadding,
batchSize * inputChannels,
inputHeight,
inputWidth,
padInputHeight,
padInputWidth);
std::function<void(
const float*, const float*, int, int, int, int, int, int, float*)>
......
......@@ -183,7 +183,7 @@ function build() {
============================================
EOF
make clean
make -j `nproc`
make install -j `nproc`
}
function build_android() {
......
......@@ -36,9 +36,11 @@ class DataToLoDTensorConverter(object):
self.dtype = 'float64'
elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32'
elif dtype == core.VarDesc.VarType.UINT8:
self.dtype = 'uint8'
else:
raise ValueError("dtype must be any of [int32, float32, int64, "
"float64]")
"float64, uint8]")
self.data = []
self.lod = []
......
......@@ -82,6 +82,7 @@ __all__ = [
'roi_pool',
'dice_loss',
'upsampling_bilinear2d',
'random_crop',
]
......@@ -154,7 +155,8 @@ def fc(input,
Examples:
.. code-block:: python
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
data = fluid.layers.data(
name="data", shape=[32, 32], dtype="float32")
fc = fluid.layers.fc(input=data, size=1000, act="tanh")
"""
......@@ -177,11 +179,8 @@ def fc(input,
inputs={"X": input_var,
"Y": w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1,
"use_mkldnn": use_mkldnn
})
attrs={"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1})
mul_results.append(tmp)
if len(mul_results) == 1:
......@@ -349,7 +348,8 @@ def dynamic_lstm(input,
cell_activation(str): The activation for cell output. Choices = ["sigmoid",
"tanh", "relu", "identity"], default "tanh".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"],
Choices = ["sigmoid", "tanh",
"relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
......@@ -516,10 +516,12 @@ def dynamic_lstmp(input,
cell_activation(str): The activation for cell output. Choices = ["sigmoid",
"tanh", "relu", "identity"], default "tanh".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"],
Choices = ["sigmoid", "tanh",
"relu", "identity"],
default "tanh".
proj_activation(str): The activation for projection output.
Choices = ["sigmoid", "tanh", "relu", "identity"],
Choices = ["sigmoid", "tanh",
"relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
......@@ -855,7 +857,7 @@ def cos_sim(X, Y):
return out
def dropout(x, dropout_prob, is_test=False, seed=None):
def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
"""
Computes dropout.
......@@ -873,6 +875,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None):
parameter is set to None, a random seed is used.
NOTE: If an integer seed is given, always the same output
units will be dropped. DO NOT use a fixed seed in training.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: A tensor variable.
......@@ -1117,7 +1121,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
return softmax_out
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None):
helper = LayerHelper('softmax', **locals())
dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype)
......@@ -2172,7 +2176,8 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_mean(x) # [0.4375]
fluid.layers.reduce_mean(x, dim=0) # [0.15, 0.25, 0.55, 0.8]
fluid.layers.reduce_mean(x, dim=-1) # [0.475, 0.4]
fluid.layers.reduce_mean(x, dim=1, keep_dim=True) # [[0.475], [0.4]]
fluid.layers.reduce_mean(
x, dim=1, keep_dim=True) # [[0.475], [0.4]]
# x is a Tensor variable with shape [2, 2, 2] and elements as below:
# [[[1.0, 2.0], [3.0, 4.0]],
......@@ -2391,7 +2396,8 @@ def split(input, num_or_sections, dim=-1, name=None):
x0.shape # [3, 3, 5]
x1.shape # [3, 3, 5]
x2.shape # [3, 3, 5]
x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
x0, x1, x2 = fluid.layers.split(
x, num_or_sections=[2, 3, 4], dim=1)
x0.shape # [3, 2, 5]
x1.shape # [3, 3, 5]
x2.shape # [3, 4, 5]
......@@ -2610,7 +2616,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
return out
def topk(input, k):
def topk(input, k, name=None):
"""
This operator is used to find values and indices of the k largest entries
for the last dimension.
......@@ -2626,6 +2632,8 @@ def topk(input, k):
input(Variable): The input variable which can be a vector or Tensor with
higher rank.
k(int): An integer value to specify the top k largest elements.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
values(Variable): The k largest elements along each last dimensional
......@@ -3301,7 +3309,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)
out = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label)
"""
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_tmp_variable(dtype=logits.dtype)
......@@ -3348,7 +3357,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='float32')
label = fluid.layers.data(
name='label', shape=[100], dtype='float32')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.smooth_l1(x=fc, y=label)
"""
......@@ -3670,7 +3680,8 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
Examples:
.. code-block:: python
data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32")
data = fluid.layers.data(
name="data", shape=[3, 112, 112], dtype="float32")
lrn = fluid.layers.lrn(input=data)
"""
helper = LayerHelper('lrn', **locals())
......@@ -3925,10 +3936,10 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
Args:
input (Variable): The input tensor of bilinear interpolation,
This is a 4-D tensor of the shape
......@@ -3946,7 +3957,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
Returns:
out (Variable): The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w).
Examples:
.. code-block:: python
......@@ -3979,3 +3990,33 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
attrs={"out_h": out_h,
"out_w": out_w})
return out
def random_crop(input, shape, seed=1):
helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
if isinstance(seed, int):
seed_value = seed
seed = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="fill_constant",
inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value),
"force_cpu": True
})
elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": input,
"Seed": seed},
outputs={"Out": out,
"SeedOut": seed_out},
attrs={"shape": shape})
return out
......@@ -112,7 +112,7 @@ def cast(x, dtype):
return out
def concat(input, axis=0):
def concat(input, axis=0, name=None):
"""
**Concat**
......@@ -122,6 +122,8 @@ def concat(input, axis=0):
Args:
input(list): List of tensors to be concatenated
axis(int): Integer axis along which the tensors will be concatenated
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: Output variable of the concatenation
......
......@@ -16,7 +16,10 @@ import core
from contextlib import contextmanager
import os
__all__ = ['cuda_profiler', 'reset_profiler', 'profiler']
__all__ = [
'cuda_profiler', 'reset_profiler', 'profiler', 'start_profiler',
'stop_profiler'
]
NVPROF_CONFIG = [
"gpustarttimestamp",
......@@ -72,20 +75,31 @@ def reset_profiler():
core.reset_profiler()
@contextmanager
def profiler(state, sorted_key=None, profile_path='/tmp/profile'):
"""The profiler interface.
Different from cuda_profiler, this profiler can be used to profile both CPU
and GPU program. By defalut, it records the CPU and GPU operator kernels,
if you want to profile other program, you can refer the profiling tutorial
to add more records.
def start_profiler(state):
"""Enable the profiler.
Args:
state (string) : The profiling state, which should be 'CPU', 'GPU'
or 'All'. 'CPU' means only profile CPU. 'GPU' means profiling
GPU as well. 'All' also generates timeline.
"""
if core.is_profiler_enabled():
return
if state not in ['CPU', 'GPU', "All"]:
raise ValueError("The state must be 'CPU' or 'GPU' or 'All'.")
if state == "GPU":
prof_state = core.ProfilerState.kCUDA
elif state == "CPU":
prof_state = core.ProfilerState.kCPU
else:
prof_state = core.ProfilerState.kAll
core.enable_profiler(prof_state)
def stop_profiler(sorted_key=None, profile_path='/tmp/profile'):
"""Stop the profiler.
Args:
state (string) : The profiling state, which should be 'CPU' or 'GPU',
telling the profiler to use CPU timer or GPU timer for profiling.
Although users may have already specified the execution place
(CPUPlace/CUDAPlace) in the begining, for flexibility the profiler
would not inherit this place.
sorted_key (string) : If None, the profiling results will be printed
in the order of first end time of events. Otherwise, the profiling
results will be sorted by the this flag. This flag should be one
......@@ -98,17 +112,8 @@ def profiler(state, sorted_key=None, profile_path='/tmp/profile'):
profile_path (string) : If state == 'All', it will write a profile
proto output file.
"""
if state not in ['CPU', 'GPU', "All"]:
raise ValueError("The state must be 'CPU' or 'GPU' or 'All'.")
if state == "GPU":
prof_state = core.ProfilerState.kCUDA
elif state == "CPU":
prof_state = core.ProfilerState.kCPU
else:
prof_state = core.ProfilerState.kAll
core.enable_profiler(prof_state)
yield
if not core.is_profiler_enabled():
return
sorted_key = 'default' if sorted_key is None else sorted_key
if sorted_key not in ['default', 'calls', 'total', 'max', 'min', 'ave']:
raise ValueError("The sorted_key must be None or in 'calls', 'total', "
......@@ -124,3 +129,34 @@ def profiler(state, sorted_key=None, profile_path='/tmp/profile'):
# TODO(qingqing) : redirect C++ ostream to Python stream.
# with core.ostream_redirect(stdout=True, stderr=True):
core.disable_profiler(key_map[sorted_key], profile_path)
@contextmanager
def profiler(state, sorted_key=None, profile_path='/tmp/profile'):
"""The profiler interface.
Different from cuda_profiler, this profiler can be used to profile both CPU
and GPU program. By defalut, it records the CPU and GPU operator kernels,
if you want to profile other program, you can refer the profiling tutorial
to add more records.
Args:
state (string) : The profiling state, which should be 'CPU' or 'GPU',
telling the profiler to use CPU timer or GPU timer for profiling.
Although users may have already specified the execution place
(CPUPlace/CUDAPlace) in the begining, for flexibility the profiler
would not inherit this place.
sorted_key (string) : If None, the profiling results will be printed
in the order of first end time of events. Otherwise, the profiling
results will be sorted by the this flag. This flag should be one
of 'calls', 'total', 'max', 'min' or 'ave'.
The `calls` means sorting by the number of calls.
The `total` means sorting by the total execution time.
The `max` means sorting by the maximum execution time.
The `min` means sorting by the minimum execution time.
The `ave` means sorting by the average execution time.
profile_path (string) : If state == 'All', it will write a profile
proto output file.
"""
start_profiler(state)
yield
stop_profiler(sorted_key, profile_path)
......@@ -217,8 +217,6 @@ def infer(use_cuda, inference_program, params_dirname):
# The range of random integers is [low, high]
word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
pred = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=PRED_DICT_LEN - 1)
ctx_n2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_n1 = fluid.create_random_int_lodtensor(
......@@ -229,18 +227,20 @@ def infer(use_cuda, inference_program, params_dirname):
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_p2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
pred = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=PRED_DICT_LEN - 1)
mark = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=MARK_DICT_LEN - 1)
results = inferencer.infer(
{
'word_data': word,
'verb_data': pred,
'ctx_n2_data': ctx_n2,
'ctx_n1_data': ctx_n1,
'ctx_0_data': ctx_0,
'ctx_p1_data': ctx_p1,
'ctx_p2_data': ctx_p2,
'verb_data': pred,
'mark_data': mark
},
return_numpy=False)
......
......@@ -53,7 +53,7 @@ def encoder(is_sparse):
return encoder_out
def decoder_train(context, is_sparse):
def train_decoder(context, is_sparse):
# decoder
trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
......@@ -81,7 +81,7 @@ def decoder_train(context, is_sparse):
return rnn()
def decoder_decode(context, is_sparse):
def decode(context, is_sparse):
init_state = context
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
......@@ -148,31 +148,9 @@ def decoder_decode(context, is_sparse):
return translation_ids, translation_scores
def set_init_lod(data, lod, place):
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod(lod)
return res
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def train_program(is_sparse):
context = encoder(is_sparse)
rnn_out = decoder_train(context, is_sparse)
rnn_out = train_decoder(context, is_sparse)
label = pd.data(
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
cost = pd.cross_entropy(input=rnn_out, label=label)
......@@ -218,13 +196,12 @@ def train(use_cuda, is_sparse, is_local=True):
def decode_main(use_cuda, is_sparse):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
context = encoder(is_sparse)
translation_ids, translation_scores = decoder_decode(context, is_sparse)
translation_ids, translation_scores = decode(context, is_sparse)
exe = Executor(place)
exe.run(framework.default_startup_program())
......@@ -234,26 +211,32 @@ def decode_main(use_cuda, is_sparse):
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [i for i in range(batch_size)] + [batch_size]
init_lod = [1] * batch_size
init_lod = [init_lod, init_lod]
init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
for _, data in enumerate(train_data()):
init_ids = set_init_lod(init_ids_data, init_lod, place)
init_scores = set_init_lod(init_scores_data, init_lod, place)
src_word_data = to_lodtensor(map(lambda x: x[0], data), place)
feed_order = ['src_word_id']
feed_list = [
framework.default_main_program().global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
for data in train_data():
feed_dict = feeder.feed(map(lambda x: [x[0]], data))
feed_dict['init_ids'] = init_ids
feed_dict['init_scores'] = init_scores
result_ids, result_scores = exe.run(
framework.default_main_program(),
feed={
'src_word_id': src_word_data,
'init_ids': init_ids,
'init_scores': init_scores
},
feed=feed_dict,
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
print result_ids.lod()
......
......@@ -147,28 +147,6 @@ def decoder_decode(context, is_sparse):
return translation_ids, translation_scores
def set_init_lod(data, lod, place):
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod(lod)
return res
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def train_main(use_cuda, is_sparse, is_local=True):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
......@@ -192,23 +170,25 @@ def train_main(use_cuda, is_sparse, is_local=True):
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
feed_order = [
'src_word_id', 'target_language_word', 'target_language_next_word'
]
exe = Executor(place)
def train_loop(main_program):
exe.run(framework.default_startup_program())
feed_list = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
batch_id = 0
for pass_id in xrange(1):
for data in train_data():
word_data = to_lodtensor(map(lambda x: x[0], data), place)
trg_word = to_lodtensor(map(lambda x: x[1], data), place)
trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
outs = exe.run(main_program,
feed={
'src_word_id': word_data,
'target_language_word': trg_word,
'target_language_next_word': trg_word_next
},
feed=feeder.feed(data),
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])
print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
......@@ -258,26 +238,32 @@ def decode_main(use_cuda, is_sparse):
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [i for i in range(batch_size)] + [batch_size]
init_lod = [1] * batch_size
init_lod = [init_lod, init_lod]
init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=batch_size)
for _, data in enumerate(train_data()):
init_ids = set_init_lod(init_ids_data, init_lod, place)
init_scores = set_init_lod(init_scores_data, init_lod, place)
src_word_data = to_lodtensor(map(lambda x: x[0], data), place)
feed_order = ['src_word_id']
feed_list = [
framework.default_main_program().global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
for data in train_data():
feed_dict = feeder.feed(map(lambda x: [x[0]], data))
feed_dict['init_ids'] = init_ids
feed_dict['init_scores'] = init_scores
result_ids, result_scores = exe.run(
framework.default_main_program(),
feed={
'src_word_id': src_word_data,
'init_ids': init_ids,
'init_scores': init_scores
},
feed=feed_dict,
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
print result_ids.lod()
......
......@@ -152,29 +152,6 @@ def seq_to_seq_net():
return avg_cost, prediction
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = core.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def create_random_lodtensor(lod, place, low, high):
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res
def train(use_cuda, save_dirname=None):
[avg_cost, prediction] = seq_to_seq_net()
......@@ -188,22 +165,20 @@ def train(use_cuda, save_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
feed_order = ['source_sequence', 'target_sequence', 'label_sequence']
feed_list = [
framework.default_main_program().global_block().var(var_name)
for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
batch_id = 0
for pass_id in xrange(2):
for data in train_data():
word_data = to_lodtensor(map(lambda x: x[0], data), place)
trg_word = to_lodtensor(map(lambda x: x[1], data), place)
trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
outs = exe.run(framework.default_main_program(),
feed={
'source_sequence': word_data,
'target_sequence': trg_word,
'label_sequence': trg_word_next
},
feed=feeder.feed(data),
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])
......@@ -237,9 +212,23 @@ def infer(use_cuda, save_dirname=None):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
lod = [0, 4, 10]
word_data = create_random_lodtensor(lod, place, low=0, high=1)
trg_word = create_random_lodtensor(lod, place, low=0, high=1)
# Setup input by creating LoDTensor to represent sequence of words.
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[4, 6]],
# which has only one lod level. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for two sentences of
# length 4 and 6, respectively.
# Note that lod info should be a list of lists.
lod = [[4, 6]]
base_shape = [1]
# The range of random integers is [low, high]
word_data = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=1)
trg_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......
......@@ -15,7 +15,7 @@
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import debuger
from paddle.fluid import debugger
from paddle.fluid.framework import Program
......@@ -51,9 +51,9 @@ class TestDebugger(unittest.TestCase):
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
print(debuger.pprint_program_codes(p))
print(debugger.pprint_program_codes(p))
debuger.draw_block_graphviz(p.block(0), path="./test.dot")
debugger.draw_block_graphviz(p.block(0), path="./test.dot")
if __name__ == '__main__':
......
......@@ -13,31 +13,47 @@
# limitations under the License.
import unittest
from test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2
import numpy as np
import math
from op_test import OpTest
class TestMKLDNNMulOp(TestMulOp):
def init_op_test(self):
super(TestMKLDNNMulOp, self).setUp()
self.attrs = {"use_mkldnn": True}
def quantize_max_abs(x, num_bits):
range = math.pow(2, num_bits) - 1
scale = np.max(np.abs(x).flatten())
y = np.round(x / scale * range)
return y, scale
class TestMKLDNNMulOp2(TestMulOp2):
def init_op_test(self):
super(TestMKLDNNMulOp2, self).setUp()
self.attrs = {"use_mkldnn": True}
def dequantize_max_abs(x, num_bits, scale):
range = math.pow(2, num_bits) - 1
y = (scale / range) * x
return y
class TestMKLDNNFP16MulOp1(TestFP16MulOp1):
def init_op_test(self):
super(TestMKLDNNFP16MulOp1, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
def setUp(self):
self.set_args()
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype("float32")
yq, scale = quantize_max_abs(x, self.num_bits)
print 'scale ', scale
ydq = dequantize_max_abs(yq, self.num_bits, scale)
class TestMKLDNNFP16MulOp2(TestFP16MulOp2):
def init_op_test(self):
super(TestMKLDNNFP16MulOp2, self).setUp()
self.attrs = {"use_mkldnn": True}
self.inputs = {'X': yq}
self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
def set_args(self):
self.num_bits = 5
if __name__ == "__main__":
......
......@@ -21,12 +21,10 @@ from op_test import OpTest
class TestMulOp(OpTest):
def setUp(self):
self.op_type = "mul"
self.use_mkldnn = False
self.inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
......@@ -47,16 +45,11 @@ class TestMulOp(OpTest):
class TestMulOp2(OpTest):
def setUp(self):
self.op_type = "mul"
self.use_mkldnn = False
self.inputs = {
'X': np.random.random((15, 4, 12, 10)).astype("float32"),
'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
}
self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
'use_mkldnn': self.use_mkldnn
}
self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
result = result.reshape(15, 4, 8, 2, 9)
......@@ -80,11 +73,9 @@ class TestMulOp2(OpTest):
class TestFP16MulOp1(OpTest):
def setUp(self):
self.op_type = "mul"
self.use_mkldnn = False
x = np.random.random((32, 84)).astype("float16")
y = np.random.random((84, 100)).astype("float16")
self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': np.dot(x, y)}
def test_check_output(self):
......@@ -97,15 +88,10 @@ class TestFP16MulOp1(OpTest):
class TestFP16MulOp2(OpTest):
def setUp(self):
self.op_type = "mul"
self.use_mkldnn = False
x = np.random.random((15, 4, 12, 10)).astype("float16")
y = np.random.random((4, 30, 8, 2, 9)).astype("float16")
self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
'use_mkldnn': self.use_mkldnn
}
self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
result = np.dot(
x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9))
result = result.reshape(15, 4, 8, 2, 9)
......
......@@ -63,10 +63,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(
set(mul_op.attr_names),
set([
"x_num_col_dims", "y_num_col_dims", "use_mkldnn", "op_role",
"op_role_var"
]))
set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
class TestRandomCropOp(OpTest):
def setUp(self):
to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] *
5).astype("float32")
self.possible_res = [
np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]),
np.array([[5, 6, 7], [9, 10, 11]]),
np.array([[6, 7, 8], [10, 11, 12]])
]
self.op_type = "random_crop"
self.inputs = {'X': to_crop, 'Seed': np.array([10])}
self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])}
self.attrs = {'shape': [2, 3]}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
out = np.array(outs[1])
for ins in out[:]:
is_equal = [(ins == res).all() for res in self.possible_res]
self.assertIn(True, is_equal)
if __name__ == "__main__":
unittest.main()
......@@ -14,7 +14,7 @@
import math
import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
from paddle.fluid.transpiler.distribute_transpiler import split_variable
import paddle.fluid as fluid
import paddle.fluid.core as core
import random
......@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
blocks = split_dense_variable(var_list, 10, min_size)
blocks = split_variable(var_list, 10, min_size)
all_sizes = []
for s in expected_sizes:
for s2 in s:
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -69,7 +69,8 @@ packages=['paddle',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.layers',
'paddle.fluid.transpiler']
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
if '${WITH_FLUID_ONLY}'== 'OFF':
packages+=['paddle.proto',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册