diff --git a/.travis.yml b/.travis.yml
index a406841f6abf01f15826f34fe4c63b4c24486ccd..361136ac2c8d899a0d7a4d7945083fcc489551b5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -27,15 +27,6 @@ script:
# 43min timeout
paddle/scripts/paddle_docker_build.sh ${JOB}
if [ $? -eq 0 ] || [ $? -eq 142 ]; then true; else exit 1; fi;
- - |
- if [[ "$JOB" != "doc" ]]; then exit 0; fi;
- # For document only
- if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
- if [[ "$TRAVIS_BRANCH" != "develop" && ! "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then exit 0; fi;
- export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/master/scripts/deploy/deploy_docs.sh
- export DOCS_DIR=`pwd`
- cd ..
- curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH $DOCS_DIR $DOCS_DIR/build/doc/
notifications:
email:
on_success: change
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2cab76e8f01c571c931398f6492aa9aeeebf1f08..f56c5d382af8cdfb5a941ee272a0f8d22ec04d67 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
+option(WITH_INFERENCE "Compile fluid inference library" ON)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
@@ -159,6 +160,7 @@ endif()
########################################################################################
include(external/mklml) # download mklml package
+include(external/xbyak) # download xbyak package
include(external/libxsmm) # download, build, install libxsmm
include(external/zlib) # download, build, install zlib
include(external/gflags) # download, build, install gflags
@@ -175,6 +177,7 @@ include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(external/cares)
+include(external/cub)
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
diff --git a/cmake/external/cub.cmake b/cmake/external/cub.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..c94849cf4b96746e6c507db2a6310c2f305dacf5
--- /dev/null
+++ b/cmake/external/cub.cmake
@@ -0,0 +1,35 @@
+if(NOT WITH_GPU)
+ return()
+endif()
+
+include(ExternalProject)
+
+set(CUB_SOURCE_DIR ${THIRD_PARTY_PATH}/cub)
+set(CUB_INCLUDE_DIR ${CUB_SOURCE_DIR}/src/extern_cub)
+
+include_directories(${CUB_INCLUDE_DIR})
+
+ExternalProject_Add(
+ extern_cub
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ GIT_REPOSITORY "https://github.com/NVlabs/cub.git"
+ GIT_TAG "v1.8.0"
+ PREFIX ${CUB_SOURCE_DIR}
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+)
+
+if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
+ set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cub_dummy.c)
+ file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
+ add_library(cub STATIC ${dummyfile})
+else()
+ add_library(cub INTERFACE)
+endif()
+
+add_dependencies(cub extern_cub)
+
+LIST(APPEND externl_project_dependencies cub)
diff --git a/cmake/external/xbyak.cmake b/cmake/external/xbyak.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..384c2f9328296ce6a8a6293be6cc47e5063dd3c4
--- /dev/null
+++ b/cmake/external/xbyak.cmake
@@ -0,0 +1,58 @@
+# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(WITH_XBYAK ON)
+if(WIN32 OR APPLE)
+ SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)
+ return()
+endif()
+
+include(ExternalProject)
+
+set(XBYAK_PROJECT extern_xbyak)
+set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak)
+set(XBYAK_INSTALL_ROOT ${THIRD_PARTY_PATH}/install/xbyak)
+set(XBYAK_INC_DIR ${XBYAK_INSTALL_ROOT}/include)
+
+include_directories(${XBYAK_INC_DIR})
+include_directories(${XBYAK_INC_DIR}/xbyak)
+
+add_definitions(-DPADDLE_WITH_XBYAK)
+
+# xbyak options
+add_definitions(-DXBYAK64)
+add_definitions(-DXBYAK_NO_OP_NAMES)
+
+ExternalProject_Add(
+ ${XBYAK_PROJECT}
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ DEPENDS ""
+ GIT_REPOSITORY "https://github.com/herumi/xbyak.git"
+ GIT_TAG "v5.661" # Jul 26th
+ PREFIX ${XBYAK_PREFIX_DIR}
+ UPDATE_COMMAND ""
+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${XBYAK_INSTALL_ROOT}
+ CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${XBYAK_INSTALL_ROOT}
+)
+
+if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
+ set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/xbyak_dummy.c)
+ file(WRITE ${dummyfile} "const char *dummy_xbyak = \"${dummyfile}\";")
+ add_library(xbyak STATIC ${dummyfile})
+else()
+ add_library(xbyak INTERFACE)
+endif()
+
+add_dependencies(xbyak ${XBYAK_PROJECT})
+list(APPEND external_project_dependencies xbyak)
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 07bab994d354df834d0667c69f307b2d7684fb22..82c958073cba92f00a341121e36ba45531b22aec 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -264,7 +264,10 @@ function(cc_test TARGET_NAME)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
+
+ set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
+ set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
endif()
endif()
endfunction(cc_test)
@@ -329,7 +332,10 @@ function(nv_test TARGET_NAME)
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
+
+ set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
+ set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
endif()
endif()
endfunction(nv_test)
@@ -577,7 +583,9 @@ function(py_test TARGET_NAME)
set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
- COMMAND env FLAGS_init_allocated_mem=true PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
+ COMMAND env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
+ FLAGS_cpu_deterministic=true
+ PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
diff --git a/doc/fluid/howto/optimization/timeline_cn.md b/doc/fluid/howto/optimization/timeline_cn.md
index 5d061e1c00d2ca0194153730a39486b8357fa5b0..faf39f276dbddcd4961407ba2d082c9826051cbe 100644
--- a/doc/fluid/howto/optimization/timeline_cn.md
+++ b/doc/fluid/howto/optimization/timeline_cn.md
@@ -1,21 +1,27 @@
# 如何使用timeline工具做性能分析
-1. 在训练的主循环外加上`with profiler.profiler(...)`。运行之后,代码会在`/tmp/profile`目录下生成一个profile的记录文件。
+1. 在训练的主循环外加上`profiler.start_profiler(...)`和`profiler.stop_profiler(...)`。运行之后,代码会在`/tmp/profile`目录下生成一个profile的记录文件。
**提示:**
请不要在timeline记录信息时运行太多次迭代,因为timeline中的记录数量和迭代次数是成正比的。
```python
- with profiler.profiler('All', 'total', '/tmp/profile') as prof:
- for pass_id in range(pass_num):
- for batch_id, data in enumerate(train_reader()):
- exe.run(fluid.default_main_program(),
- feed=feeder.feed(data),
- fetch_list=[])
+ for pass_id in range(pass_num):
+ for batch_id, data in enumerate(train_reader()):
+ if pass_id == 0 and batch_id == 5:
+ profiler.start_profiler("All")
+ elif pass_id == 0 and batch_id == 10:
+ profiler.stop_profiler("total", "/tmp/profile")
+ exe.run(fluid.default_main_program(),
+ feed=feeder.feed(data),
+ fetch_list=[])
...
```
1. 运行`python paddle/tools/timeline.py`来处理`/tmp/profile`,这个程序默认会生成一个`/tmp/timeline`文件,你也可以用命令行参数来修改这个路径,请参考[timeline.py](https://github.com/PaddlePaddle/Paddle/blob/develop/tools/timeline.py)。
+```python
+python Paddle/tools/timeline.py --profile_path=/tmp/profile --timeline_path=timeline
+```
1. 打开chrome浏览器,访问,用`load`按钮来加载生成的`timeline`文件。
diff --git a/doc/fluid/howto/optimization/timeline_en.md b/doc/fluid/howto/optimization/timeline_en.md
index 96481ae2a6e4442d40803f8d5361e5f942502df3..6f963c6b4da6967fb2f493ada917a4b08917fa4c 100644
--- a/doc/fluid/howto/optimization/timeline_en.md
+++ b/doc/fluid/howto/optimization/timeline_en.md
@@ -1,15 +1,17 @@
# how to use timeline tool to do profile
-1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when use profiler to record timeline information, for the profile record will grow with the batch number.
+1. Add `profiler.start_profiler(...)`和`profiler.stop_profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when use profiler to record timeline information, for the profile record will grow with the batch number.
```python
- with profiler.profiler('All', 'total', '/tmp/profile') as prof:
- for pass_id in range(pass_num):
- for batch_id, data in enumerate(train_reader()):
- exe.run(fluid.default_main_program(),
- feed=feeder.feed(data),
- fetch_list=[],
- use_program_cache=True)
+ for pass_id in range(pass_num):
+ for batch_id, data in enumerate(train_reader()):
+ if pass_id == 0 and batch_id == 5:
+ profiler.start_profiler("All")
+ elif pass_id == 0 and batch_id == 10:
+ profiler.stop_profiler("total", "/tmp/profile")
+ exe.run(fluid.default_main_program(),
+ feed=feeder.feed(data),
+ fetch_list=[])
...
```
@@ -17,6 +19,10 @@
file `/tmp/timeline` by default. You can change the path by cmd parameter, please take a look at
[timeline.py](https://github.com/PaddlePaddle/Paddle/blob/develop/tools/timeline.py) for details.
+```python
+python Paddle/tools/timeline.py --profile_path=/tmp/profile --timeline_path=timeline
+```
+
1. Open chrome and visit , use `load` button to load the generated `timeline` file.
![chrome tracing](./tracing.jpeg)
diff --git a/doc/survey/op_fusion_design.md b/doc/survey/op_fusion_design.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6e48f4f58269b67450cb012f6dcc59e1083abba
--- /dev/null
+++ b/doc/survey/op_fusion_design.md
@@ -0,0 +1,20 @@
+# Operator fusion
+Fusing multiple operators together is an important method to optimize the program execution, particularly for GPU or other specialized accelerators. An obvious benefit is to avoid the overhead of saving the intermediate result back into global memory.
+
+There are generally two ways to fuse operators, fusing directly connected operators and fusing non directly connected operators. The first method is mainly used by [NNVM Compiler](https://github.com/dmlc/tvm/) and [XLA](https://www.tensorflow.org/performance/xla/). The second method is mainly used by Dynet and TensorFlow Fold to do auto-batching. The principle of fusing operator is according to some rules to combine multiple operations into one, for example, `Y = X * W` and `Z = Y + B` can be fused to `Z = X * W + B`, and `Y1 = X1 * W` and `Y2 = X2 * W` can be fused to `[Y1;Y2] = [X1;X2] * W`. In order to get a short-term profit, we decided to try to manually specify these rules.
+
+## Challenge
+The challenge of fusing operators is:
+ - how to make the rules.
+ - how to implement these rules efficiently.
+
+### How to make the rules?
+
+The problem of determining the best single location for a fusion operator is an NP-hard combinatorial problem. After analysis the operators of the DL model, we found there are two group of operators can be fused explicitly, one is the simple and adjacent operations, for example, `tmp = x + y` and `z = Relu(tmp)`, and the other is the operators that have the same function, for example, a serials of `SGD` or `Momentum`. They usually appear in the model in a large number. So we should think about how to fuse them separately first.
+
+### How to implement these rules efficiently?
+#### How to fuse the adjacent operations efficiently?
+Here we use a template function to represent the fused operations. The pros of using a template function are that it is simple and efficient, and the cons are that it is not easy to expand, and it can only be used to express some simple operations. So taking into account our current needs, the template function is more appropriate.
+
+#### How to fuse the operators that have the same function efficiently?
+We take SGD operator as an example, the training model may have hundreds of parameters and correspondingly have the same number of SGD operators. The expression(`w = w - lr*w_g`) of those operators is the same, so during of training, the executor will execute this expression hundreds time in CPU or other specialized accelerators. If we can fuse them and make the address of all `w` and all `w_g` continuous respectively, we only need execute one time. For some accelerators, the time of launching kernel is not neglected, so the time of hundreds of times of launching and executing kernel may be larger than launching and executing only once. There usually are many operators that similar to `SGD` in the DL model, such as `AllReduce` and `FC`.
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 3ef317bb7a1c25c5738342f34ae7994b0184a7de..dd172ff9c97814c089ddb2e5bf729880cf0c9cdb 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -336,6 +336,7 @@ paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=Non
paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False))
paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt
index d274d96c29bdbf5973d568d783369c3975bdc436..2577e59d9cf24c26b7c04aa00cdde6cde17f7206 100644
--- a/paddle/fluid/CMakeLists.txt
+++ b/paddle/fluid/CMakeLists.txt
@@ -5,5 +5,7 @@ add_subdirectory(operators)
add_subdirectory(pybind)
add_subdirectory(string)
add_subdirectory(recordio)
-# NOTE: please add subdirectory inference at last.
-add_subdirectory(inference)
+if(WITH_INFERENCE)
+ # NOTE: please add subdirectory inference at last.
+ add_subdirectory(inference)
+endif()
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 700c73c745bad72637d77385f5cd38c494501c86..bf493a3fa44e48deec734250d04b2a413c3ed9da 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
@@ -45,6 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
void AllReduceOpHandle::RunImpl() {
+ platform::RecordEvent r("all_reduce", nullptr);
if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index b2e5399e2376a86c1cd310b29c768832665af87f..8714a42162bda3d5ad12e7925fe8cc4e693f51b1 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -21,6 +21,26 @@ namespace framework {
namespace details {
struct BuildStrategy {
+ // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and
+ // kReduce, for CPU and GPU. If you use kAllReduce, different threads
+ // optimize their parameters separately. If you use kReduce, the optimizations
+ // of parameters are distributed to different threads.
+ // For example, a model has 100 parameters and is running with four threads,
+ // if you choose kAllReduce, every thread is to optimize 100 parameters
+ // separately, if you choose kReduce, every thread is to optimize 25
+ // parameters.
+ // Of particular note is, if you use kReduce when using CPU training,
+ // all the parameters are shared between different threads. This feature will
+ // save memory.
+ // FIXME(zcd): The result of the two modes(kAllReduce and kReduce) maybe not
+ // equal for GPU. Because, the result of the different order of summing maybe
+ // different, for example, the result of `a+b+c+d` may be different with the
+ // result of `c+a+b+d`.
+ // For GPU, the implementation of kAllReduce and kReduce is adopted NCCL,
+ // so the result of kAllReduce and kReduce maybe not equal.
+ // For CPU, if you want to fix the order of summing to make the result
+ // of kAllReduce and kReduce no diff, you can add
+ // `FLAGS_cpu_deterministic=true` to env.
enum class ReduceStrategy { kAllReduce = 0, kReduce = 1 };
enum class GradientScaleStrategy {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 5ca2ed8f96244a11925dfa6af8e48458cf334ecd..a4fdbcb26d1d0cfb05edebff5419d9559c336b3a 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -275,7 +275,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl(
if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
- CreateScaleLossGradOp(&result);
+ auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
+ CreateScaleLossGradOp(&result, loss_grad_name);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
@@ -535,7 +536,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
return got == sharded_var_device.end() ? -1 : got->second;
}
-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
+ ir::Graph *result, const std::string &loss_grad_name) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
@@ -558,10 +560,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
- CreateOpOutput(result, op_handle,
- result->CreateEmptyNode(GradVarName(loss_var_name_),
- ir::Node::Type::kVariable),
- places_[i], i);
+ CreateOpOutput(
+ result, op_handle,
+ result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
+ places_[i], i);
}
}
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 099dbe5abef6458c4613c9f680440734f59cb6e2..f2cb6bb1c861e07f1034f1742ad4f3cfbb0d8837 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -75,7 +75,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const;
- void CreateScaleLossGradOp(ir::Graph *result) const;
+ void CreateScaleLossGradOp(ir::Graph *result,
+ const std::string &loss_grad_name) const;
+
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc
index 7160e346dad0615e2fd32b70c096880af0359e1a..6c7e5c1fb06620b1c071b00fcfcc1b4a29bf8d62 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle.cc
@@ -16,12 +16,18 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
+
+DEFINE_bool(
+ cpu_deterministic, false,
+ "Whether to make the result of computation deterministic in CPU side.");
namespace paddle {
namespace framework {
namespace details {
void ReduceOpHandle::RunImpl() {
+ platform::RecordEvent r("reduce", nullptr);
if (places_.size() == 1) return;
// the input and output may have dummy var.
auto in_var_handles = DynamicCast(inputs_);
@@ -89,11 +95,33 @@ void ReduceOpHandle::RunImpl() {
} else {
std::vector lod_tensors =
GetInputValues(in_var_handles, var_scopes);
+
if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
this->RunAndRecordEvent([&] {
- ReduceLoDTensor func(lod_tensors,
- out_var->GetMutable());
- VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+ // FIXME(zcd): The order of summing is important,
+ // especially when the type of data is float or double.
+ // For example, the result of `a+b+c+d` may be different
+ // with the result of `c+a+b+d`, so the summing order should be fixed.
+ if (!FLAGS_cpu_deterministic) {
+ ReduceLoDTensor func(lod_tensors,
+ out_var->GetMutable());
+ VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+ } else {
+ // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
+ // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
+ auto &reduce_sum_trg = *this->local_scopes_[0]
+ ->FindVar(kLocalExecScopeName)
+ ->Get()
+ ->FindVar(out_var_handle->name_)
+ ->GetMutable();
+ ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
+ VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+
+ auto trg = out_var->GetMutable();
+ if (reduce_sum_trg.data() != trg->data()) {
+ TensorCopy(reduce_sum_trg, platform::CPUPlace(), trg);
+ }
+ }
});
} else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
index 1d80bab90f513139f807b57258177c6b2ac53ac0..5bd974d6b789a2f085c0a69de5e133187342f587 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
@@ -17,6 +17,7 @@
#include
#include
#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
@@ -62,6 +63,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
eptr = std::current_exception();
}
+ platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index e556c84b0219eba2b92c456c205e03947171626b..0eaf9a9c951991a5775604eb8d0e7535f81a4ae2 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
@@ -34,6 +35,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector &fetch_tensors) {
+ std::unique_ptr event(
+ new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
std::unordered_map pending_ops;
std::unordered_set pending_vars;
BlockingQueue ready_vars;
@@ -84,6 +87,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Clean run context
run_op_futures_.clear();
exception_holder_.Clear();
+ event.reset(nullptr);
// Step 3. Execution
while (!pending_vars.empty()) {
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index c2800c972a5501859672fbfd6921499e84d09cb0..dad170ed78c64202b5c812bd8682887fe3b736d6 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -330,12 +330,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
for (auto& op : ctx->ops_) {
- VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
- // NOTE! Please do not delete this line, it's usefull because the debug
- // string before and after op.run are different, after run the output
- // will have right shape which is usefull for debug.
- VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index f870fb2b9cf805aba84d6f4573b0574ff361e71c..f87d5212c0cd87a5a63cf2d54ca677516ab45816 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -182,9 +182,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
}
/**
- * We only handle write after read(WAR), since it should not have a write
- * after write in program. If there are write after write operators, we need
- * prune them.
+ * We should handle write after read(WAR) and write after write(WAW) here.
+ * Because some of the operators of the program can be executed parallelly.
+ * So, to make the program running in the right order, we should add the
+ * dependence of WAR and WAW.
+ *
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
@@ -201,6 +203,19 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
(*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
const auto &read_ops = (*it_old)->outputs;
+ PADDLE_ENFORCE(write_op, "The write_op should not be empty.");
+
+ // Add write after write dependence
+ ir::Node *upstream_op =
+ (*it_old)->inputs.empty() ? nullptr : (*it_old)->inputs[0];
+ if (upstream_op) {
+ ir::Node *dep_var = CreateControlDepVar();
+ write_op->inputs.push_back(dep_var);
+ upstream_op->outputs.push_back(dep_var);
+ dep_var->outputs.push_back(write_op);
+ dep_var->inputs.push_back(upstream_op);
+ }
+
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index cdac00739bc48648b41751e644a953d0d310ffbf..d04f7744961b2561977f4d36d0073a97557043da 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -127,7 +127,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
- VLOG(10) << "- " << DebugStringEx(&scope);
+ VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
@@ -136,8 +136,10 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
platform::SetDeviceId(dev_id);
#endif
}
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
- VLOG(10) << "+ " << DebugStringEx(&scope);
+ VLOG(3) << place << " " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
@@ -639,9 +641,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
- // For profiling, don't move out of this function because that will result
- // in the failure of multi-GPU profiling.
- platform::RecordEvent record_event(Type(), dev_ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
@@ -779,6 +778,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
+ std::string last_input_name;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
@@ -795,9 +795,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
int tmp = static_cast(ToDataType(t->type()));
PADDLE_ENFORCE(
tmp == data_type || data_type == -1,
- "DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
- data_type, tmp);
+ "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
+ Type(), last_input_name, data_type, ipt_name, tmp);
data_type = tmp;
+ last_input_name = ipt_name;
}
}
}
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index 98bdfcc00b9f0e8f40dfc92e4021b2bd6fb19313..c4ab26a2288bb9d8f3cd54a797d2062e0606b219 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -24,7 +24,7 @@
namespace paddle {
-DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
+DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true,
"Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./",
@@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
// TODO(Superjomn) set the key with pass reprs.
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
- auto trt_teller = [](const Node* node) {
+ auto trt_teller = [&](const Node* node) {
+ std::unordered_set teller_set(
+ {"elementwise_add", "mul", "conv2d", "pool2d", "relu"});
if (!node->IsFunction()) return false;
- return static_cast(node)->func_type() == "mul";
+
+ const auto* func = static_cast(node);
+ if (teller_set.count(func->func_type()))
+ return true;
+ else {
+ return false;
+ }
};
+
AddPass("tensorrt-subgraph-marker",
new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc
index 8a3af0a8ebd5bad7be7046fa399cca4920da3d71..7f64bc75ae8ad40a268739cdc36051e76af9f49a 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph.cc
@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT
std::vector(outputs.begin(), outputs.end()));
}
+void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
+ std::vector op_nodes;
+ for (auto &node : GraphTraits(graph).nodes_in_TS()) {
+ if (node.type() == Node::Type::kValue || node.deleted()) {
+ continue;
+ }
+ op_nodes.push_back(&node);
+ }
+ size_t op_num = op_nodes.size();
+ for (size_t i = 0; i < op_num; i++) {
+ if (op_nodes[i]->type() == Node::Type::kFunction) continue;
+ std::unordered_set follow_up_input_names;
+ for (size_t j = i + 1; j < op_num; j++) {
+ for (auto *in : op_nodes[j]->inlinks) {
+ follow_up_input_names.insert(in->name());
+ }
+ }
+ std::vector filtered_subgraph_outlinks;
+ for (auto *out : op_nodes[i]->outlinks) {
+ if (follow_up_input_names.count(out->name())) {
+ filtered_subgraph_outlinks.push_back(out);
+ }
+ }
+ PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
+ op_nodes[i]->outlinks = filtered_subgraph_outlinks;
+ }
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h
index 16aeae4d35e7bd54646053190da7f47eaca69aa0..bb3ec6bbc1d9555386aba8837b019d2511653258 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.h
+++ b/paddle/fluid/inference/analysis/data_flow_graph.h
@@ -178,6 +178,7 @@ struct GraphTraits {
std::pair, std::vector>
ExtractInputAndOutputOfSubGraph(std::vector &graph); // NOLINT
+void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
index 2328d870422c5a31c22d7b09980aae35e01b2b25..18c32fa09199003f17183207828cdfe4e627ae1a 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
@@ -23,7 +23,7 @@
namespace paddle {
namespace inference {
-DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size");
+DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size");
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
namespace analysis {
@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
+ FilterRedundantOutputOfSubGraph(graph);
LOG(INFO) << "graph.inputs " << graph->inputs.size();
for (auto &node : GraphTraits(graph).nodes_in_TS()) {
if (node.deleted()) continue;
@@ -87,34 +88,113 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
}
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
- const framework::proto::BlockDesc &block) {
+ framework::proto::BlockDesc *block) {
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
auto *func = static_cast(node);
// collect inputs
- std::vector io;
+ std::unordered_set input_names;
for (auto *x : func->inlinks) {
- io.push_back(x->name());
+ input_names.insert(x->name());
}
- desc.SetInput("Xs", io);
+ desc.SetInput(
+ "Xs", std::vector(input_names.begin(), input_names.end()));
- // collect outputs
- io.clear();
+ std::unordered_set output_names;
for (auto *x : func->outlinks) {
- io.push_back(x->name());
+ output_names.insert(x->name());
}
- desc.SetOutput("Ys", io);
+
+ std::vector output_temp(output_names.begin(),
+ output_names.end());
+ desc.SetOutput("Ys", output_temp);
desc.SetType("tensorrt_engine");
- PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
+ std::unordered_map output_name_map;
+
+ // The following procedure is used to rename all the intermediate
+ // variables and the output variables of the subgraph.
+ // Why we do this?
+ // During the transition from fluid OP to tensorrt OP, we map
+ // the input and output Tensor(fluid data structure) of fluid OP
+ // to the correspondin ITensor (trt data structure) through the
+ // Tensor name. When we set up ITensor for an variable, we must
+ // ensure that it has not been set before.
+ // If there is variable in the fluid graph, which is not only the
+ // input of a OP, but also the output of a Op, there will be problems.
+ // So we have to rename the variable in the subgraph to make sure
+ // it is either an OP's input or an OP's output.
+
+ auto subgraph_nodes = func->subgraph;
+ for (int index = 0; index < block->ops_size(); index++) {
+ framework::proto::OpDesc *op = block->mutable_ops(index);
+ auto correspond_node = subgraph_nodes[index];
+ PADDLE_ENFORCE_EQ(correspond_node->name(), op->type());
+
+ std::unordered_map var2id;
+ for (auto *in_var : correspond_node->inlinks) {
+ var2id[in_var->name()] = in_var->id();
+ }
+ // rename for the input variables of op inside subgraph
+ for (int i = 0; i < op->inputs_size(); i++) {
+ framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i);
+ std::vector replaced_names;
+ for (int k = 0; k < in_var->arguments_size(); k++) {
+ std::string arg_value = in_var->arguments(k);
+ if (input_names.count(arg_value)) {
+ replaced_names.push_back(arg_value);
+ } else {
+ replaced_names.push_back(arg_value +
+ std::to_string(var2id[arg_value]));
+ }
+ }
+ in_var->clear_arguments();
+ for (size_t k = 0; k < replaced_names.size(); k++) {
+ in_var->add_arguments(replaced_names[k]);
+ }
+ }
+ var2id.clear();
+ for (auto out_var : correspond_node->outlinks) {
+ var2id[out_var->name()] = out_var->id();
+ }
+
+ // rename for the output variables of op inside subgraph
+ for (int i = 0; i < op->outputs_size(); i++) {
+ framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
+ std::vector replaced_names;
+ for (int k = 0; k < out_var->arguments_size(); k++) {
+ std::string arg_value = out_var->arguments(k);
+ if (output_names.count(arg_value)) {
+ output_name_map[arg_value] =
+ arg_value + std::to_string(var2id[arg_value]);
+ }
+ replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
+ }
+ out_var->clear_arguments();
+ for (size_t k = 0; k < replaced_names.size(); k++) {
+ out_var->add_arguments(replaced_names[k]);
+ }
+ }
+ }
+ // When tensorrt engine runs at the end of the operation,
+ // output_mapping help us copy the data from the renamed ITensor
+ // to Tensor.
+ std::vector output_mapping;
+ for (auto name : output_names) {
+ PADDLE_ENFORCE(output_name_map.count(name) != 0);
+ output_mapping.push_back(output_name_map[name]);
+ }
+
+ PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
// Set attrs
- SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
+ SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
+ SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
node->SetPbMsg(desc.Proto()->SerializeAsString());
}
@@ -146,15 +226,17 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
LOG(INFO) << "transformed variable size: "
<< block_desc.Proto()->vars().size();
// copy ops.
+
for (auto *node : block_node->subgraph) {
auto *op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg());
}
+
*block_desc.Proto()->mutable_vars() =
argument_->origin_program_desc->blocks(0).vars();
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
- CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
+ CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto());
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
index a6f85484756417e103cbb60bcb664e8b800b9f28..c05b0e5d4690d0a447edf63a149903704bc2c9be 100644
--- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
+++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
- for (auto &in : node.inlinks) {
- if (!config_.display_deleted_node && in->deleted()) continue;
- dot.AddEdge(in->repr(), node.repr(), {});
+ for (auto &out : node.outlinks) {
+ if (!config_.display_deleted_node && out->deleted()) continue;
+ dot.AddEdge(node.repr(), out->repr(), {});
}
}
return dot.Build();
diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc
index 389f9e1a9148a4daf0e5b751cce5cb6325252a4e..80809d4c43ca08298bad25cf614dcb4117d3f99a 100644
--- a/paddle/fluid/inference/analysis/subgraph_splitter.cc
+++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc
@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
std::vector> SubGraphSplitter::ExtractSubGraphs() {
std::vector marked_nodes;
- for (auto &node : GraphTraits(graph_).nodes()) {
+ for (auto &node : GraphTraits(graph_).nodes_in_TS()) {
if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node);
}
diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index 259d79bedbf664f52b1189ca71567665a6d91180..08d0f493ab30d92a121d089d9003bc575429b4dd 100644
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -74,9 +74,10 @@ if (WITH_ANAKIN) # only needed in CI
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
target_link_libraries(inference_anakin_api_shared anakin anakin_saber_common)
if (WITH_TESTING)
- cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc
- ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin
- DEPS inference_anakin_api_shared)
- target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
+ # this test is unstable, disable it first.
+ #cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc
+ #ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin
+ #DEPS inference_anakin_api_shared)
+ #target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
endif(WITH_TESTING)
endif()
diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc
index e74f23ff969f5a8f58a71da337c16dcbc14f10c0..63c3f0d7b3f5c2b9246e2b041796caf5eb562826 100644
--- a/paddle/fluid/inference/api/api.cc
+++ b/paddle/fluid/inference/api/api.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
@@ -40,19 +41,36 @@ PaddleBuf::PaddleBuf(PaddleBuf&& other)
PaddleBuf::PaddleBuf(const PaddleBuf& other) { *this = other; }
PaddleBuf& PaddleBuf::operator=(const PaddleBuf& other) {
+ if (!other.memory_owned_) {
+ data_ = other.data_;
+ length_ = other.length_;
+ memory_owned_ = other.memory_owned_;
+ } else {
+ Resize(other.length());
+ memcpy(data_, other.data(), other.length());
+ length_ = other.length();
+ memory_owned_ = true;
+ }
+ return *this;
+}
+
+PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
// only the buffer with external memory can be copied
- assert(!other.memory_owned_);
data_ = other.data_;
length_ = other.length_;
memory_owned_ = other.memory_owned_;
+ other.data_ = nullptr;
+ other.length_ = 0;
+ other.memory_owned_ = false;
return *this;
}
void PaddleBuf::Resize(size_t length) {
// Only the owned memory can be reset, the external memory can't be changed.
if (length_ == length) return;
- assert(memory_owned_);
- Free();
+ if (memory_owned_) {
+ Free();
+ }
data_ = new char[length];
length_ = length;
memory_owned_ = true;
@@ -68,7 +86,7 @@ void PaddleBuf::Reset(void* data, size_t length) {
void PaddleBuf::Free() {
if (memory_owned_ && data_) {
assert(length_ > 0);
- delete static_cast(data_);
+ delete[] static_cast(data_);
data_ = nullptr;
length_ = 0;
}
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 59b0df7968cce137843ba8cad38a62fdb8d3bfc1..b24414e8245b1a4d90acce4fa1ad5690e06b47dd 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -40,11 +40,12 @@ class PaddleBuf {
// Copy only available when memory is managed externally.
explicit PaddleBuf(const PaddleBuf&);
PaddleBuf& operator=(const PaddleBuf&);
+ PaddleBuf& operator=(PaddleBuf&&);
// Do not own the memory.
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
- explicit PaddleBuf(size_t length)
+ PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index d86c046f2e5b08a4c00cf6cad19627e6a196c798..8f42a37cd3f8978b917b42e8f45a128b8422aa57 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,6 +1,7 @@
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
+activation_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 1b6a0ad82f3ceb00cec15c28c8121adc22271b7a..41faaf7212accaaec238062b1340e8da8fa6be33 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -55,7 +55,6 @@ class OpConverter {
it = Registry::Lookup("fc");
}
}
-
if (op_desc.Type().find("elementwise") != std::string::npos) {
static std::unordered_set add_tensor_op_set{
"add", "mul", "sub", "div", "max", "min", "pow"};
@@ -72,6 +71,8 @@ class OpConverter {
"Unsupported elementwise type" + op_type);
it =
Registry::Lookup("elementwise_" + op_type + "_weight");
+ PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
+ op_desc.Type());
} else {
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type);
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 0511eb42a073ac305634110a71a35e501f062132..f07ab5a33b87d7945e5fcdf8f3644f0711ce643b 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
+ std::string data_format = ctx.Attr("data_format");
+ auto chosen_memory_format =
+ platform::data_format_to_memory_format(data_format);
+
auto src_md = platform::MKLDNNMemDesc(
- src_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ src_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
- weights_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ weights_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto dst_md = platform::MKLDNNMemDesc(
- dst_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
+ std::string data_format = ctx.Attr("data_format");
+ auto chosen_memory_format =
+ platform::data_format_to_memory_format(data_format);
+
auto src_md = platform::MKLDNNMemDesc(
- src_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ src_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
- src_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ src_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
- weights_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ weights_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
- weights_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ weights_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
- dst_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format);
// Retrieve conv_pd from device context
auto conv_pd =
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index eb8272e90c32c3a0be2c0ce1bc679571af876317..bc3e95e904f8b6c2cdd2ae6685bf67580178e6b6 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
const framework::Tensor& dout, int axis,
framework::Tensor* dx, framework::Tensor* dy,
DX_OP dx_op, DY_OP dy_op) {
- const framework::DDim x_dim = x.dims();
- const framework::DDim y_dim = y.dims();
+ const framework::DDim& x_dim = x.dims();
+ const framework::DDim& y_dim = y.dims();
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
@@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
framework::Tensor* dx, framework::Tensor* dy,
DX_OP dx_op, DY_OP dy_op) {
if (dy == nullptr) {
- const framework::DDim dx_dims = dout.dims();
+ const framework::DDim& dx_dims = dout.dims();
auto dy_dims = dx_dims;
ElemwiseGradComputeNoBroadcast(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
if (dout.dims() == dy->dims()) {
- const framework::DDim dx_dims = dout.dims();
- const framework::DDim dy_dims = dy->dims();
+ const framework::DDim& dx_dims = dout.dims();
+ const framework::DDim& dy_dims = dy->dims();
ElemwiseGradComputeNoBroadcast(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else { // Y is a scalar
auto dx_dims = dout.dims();
- const framework::DDim dy_dims = dy->dims();
+ const framework::DDim& dy_dims = dy->dims();
ElemwiseGradComputeWithBroadcast(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
diff --git a/paddle/fluid/operators/feed_op.cc b/paddle/fluid/operators/feed_op.cc
index bcb3e63ed7dbc775c1de6c4522f0548ea48a6cf0..dc7ef664958238ddbd48745bd59cc7db28e49f5b 100644
--- a/paddle/fluid/operators/feed_op.cc
+++ b/paddle/fluid/operators/feed_op.cc
@@ -31,7 +31,6 @@ class FeedOp : public framework::OperatorBase {
const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
- platform::RecordEvent record_event(Type(), dev_ctx);
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc
index 680fde19eefe57475b7526ebc29d4ff977a16977..d9cd956dfdff3d009d38ee5088f5396080580483 100644
--- a/paddle/fluid/operators/fetch_barrier_op.cc
+++ b/paddle/fluid/operators/fetch_barrier_op.cc
@@ -36,12 +36,6 @@ class FetchBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector eps = Attr>("endpoints");
-
- platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
- auto& ctx = *pool.Get(place);
- // For profiling
- platform::RecordEvent record_event(Type(), &ctx);
-
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance();
diff --git a/paddle/fluid/operators/fetch_op.cc b/paddle/fluid/operators/fetch_op.cc
index 1640a2a22c69a0e3ab81a2889d6105b2cf4162b7..c197b45e8196a47def6465128e8ca39d8daefed6 100644
--- a/paddle/fluid/operators/fetch_op.cc
+++ b/paddle/fluid/operators/fetch_op.cc
@@ -30,9 +30,6 @@ class FetchOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
- platform::RecordEvent record_event(Type(), pool.Get(place));
-
auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused_elemwise_activation_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a6fd0aeb021dce40339c32251af130d5984dccd2
--- /dev/null
+++ b/paddle/fluid/operators/fused_elemwise_activation_op.cc
@@ -0,0 +1,221 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+
+#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext *ctx) const override {
+ PADDLE_ENFORCE(
+ ctx->HasInput("X"),
+ "Input(X) of FusedElemwiseActivationOp op should not be null.");
+ PADDLE_ENFORCE(
+ ctx->HasInput("Y"),
+ "Input(Y) of FusedElemwiseActivationOp op should not be null.");
+ PADDLE_ENFORCE(
+ ctx->HasOutput("Out"),
+ "Output(Out) of FusedElemwiseActivationOp op should not be null.");
+
+ auto x_dim = ctx->GetInputDim("X");
+ auto y_dim = ctx->GetInputDim("Y");
+ PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
+ "Rank of first input must >= rank of second input.");
+
+ ctx->SetOutputDim("Out", x_dim);
+ ctx->ShareLoD("X", /*->*/ "Out");
+ }
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext &ctx) const override {
+ PADDLE_ENFORCE_EQ(ctx.Input("X")->type(),
+ ctx.Input("Y")->type(),
+ "The element's type of input should be the same.");
+ auto input_data_type =
+ framework::ToDataType(ctx.Input("X")->type());
+ return framework::OpKernelType(input_data_type, ctx.GetPlace());
+ }
+};
+
+class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddInput("X", "(vector)");
+ AddInput("Y", "(vector)");
+ AddOutput("Out", "vector");
+ AddAttr("axis",
+ "axis is used by elementwise_op, the default value is -1.")
+ .SetDefault(-1);
+ AddAttr("scale",
+ "scale is used by scale_op, the default value is 0.0.")
+ .SetDefault(0.0);
+ AddAttr("recomputation",
+ "Whether to recompute the Out."
+ "fused_elemwise_activation_grad has two methods to get the "
+ "dx and dy, one "
+ "is to use the 'Out', and the other is not to use it. "
+ "The former method will save the time of recomputing the "
+ "'Out', but it must occupy the memory to store the 'out'. "
+ "While, the later method can avoid occupying the memory, "
+ "but it must recompute the 'Out'. The default value is true.")
+ .SetDefault(true);
+ AddAttr>("functor_list",
+ "The functors that should be fused.")
+ .AddCustomChecker([&](const std::vector &functor_list) {
+ PADDLE_ENFORCE(ValidCheck(functor_list));
+ });
+
+ AddComment(R"DOC(
+FusedElemwiseActivation Operator.
+
+At present, FusedElemwiseActivation only supports Two kinds of compound
+operators (elementwise_op and activation_op):
+
+ Z = Binary(X, Unary(Y))
+ Z = Unary(Binary(X, Y))
+
+The attributions of activation_op can be get from fused_elemwise_activation_op's
+attributions. functor_list records the functors to be fused, for example
+"scale,elementwise_add".
+
+)DOC");
+ }
+
+ private:
+ bool ValidCheck(const std::vector &functors) {
+ std::unordered_set unary_fun = {"scale", "relu"};
+ std::unordered_set binary_fun = {"elementwise_add"};
+
+ std::string unary_fun_str;
+ if (binary_fun.count(functors[0])) {
+ unary_fun_str = functors[1];
+ } else if (binary_fun.count(functors[1])) {
+ unary_fun_str = functors[0];
+ } else {
+ PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
+ functors[1]);
+ }
+ PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
+ "%s is not included in fused_list.", unary_fun_str);
+ return true;
+ }
+};
+
+class FusedElemwiseActivationGradMaker
+ : public framework::SingleGradOpDescMaker {
+ public:
+ using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+ std::unique_ptr Apply() const override {
+ auto *op_desc_ptr = new framework::OpDesc();
+ op_desc_ptr->SetType(this->ForwardOpType() + "_grad");
+
+ for (auto &input_param : this->InputNames()) {
+ op_desc_ptr->SetInput(input_param, this->Input(input_param));
+ op_desc_ptr->SetOutput(framework::GradVarName(input_param),
+ this->InputGrad(input_param, true));
+ }
+
+ for (auto &output_param : this->OutputNames()) {
+ op_desc_ptr->SetInput(output_param, this->Output(output_param));
+ op_desc_ptr->SetInput(framework::GradVarName(output_param),
+ this->OutputGrad(output_param));
+ }
+ op_desc_ptr->SetAttrMap(this->Attrs());
+
+ std::vector functor_names =
+ boost::get>(
+ op_desc_ptr->GetAttr("functor_list"));
+ functor_names[0] += "_grad";
+ functor_names[1] += "_grad";
+ op_desc_ptr->SetAttr("functor_list", functor_names);
+ return std::unique_ptr(op_desc_ptr);
+ }
+};
+
+class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
+ auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+
+ PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+ "Rank of first input must >= rank of second input.");
+
+ auto x_grad_name = framework::GradVarName("X");
+ auto y_grad_name = framework::GradVarName("Y");
+ if (ctx->HasOutput(x_grad_name)) {
+ ctx->SetOutputDim(x_grad_name, x_dims);
+ }
+ if (ctx->HasOutput(y_grad_name)) {
+ ctx->SetOutputDim(y_grad_name, y_dims);
+ }
+ }
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext &ctx) const override {
+ auto input_data_type_index = ctx.Input("X")->type();
+ PADDLE_ENFORCE_EQ(input_data_type_index,
+ ctx.Input("Y")->type(),
+ "The element's type of input should be the same.");
+ PADDLE_ENFORCE_EQ(
+ input_data_type_index,
+ ctx.Input(framework::GradVarName("Out"))->type(),
+ "The element's type of input should be the same.");
+
+ auto input_data_type = framework::ToDataType(input_data_type_index);
+ return framework::OpKernelType(input_data_type, ctx.GetPlace());
+ }
+};
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fused_elemwise_activation, ops::FusedElemwiseActivationOp,
+ ops::FusedElemwiseActivationMaker,
+ ops::FusedElemwiseActivationGradMaker);
+REGISTER_OPERATOR(fused_elemwise_activation_grad,
+ ops::FusedElemwiseActivationOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+ fused_elemwise_activation,
+ ops::FusedElemwiseActivationKernel,
+ ops::FusedElemwiseActivationKernel);
+
+REGISTER_OP_CPU_KERNEL(
+ fused_elemwise_activation_grad,
+ ops::FusedElemwiseActivationGradKernel,
+ ops::FusedElemwiseActivationGradKernel);
diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cu b/paddle/fluid/operators/fused_elemwise_activation_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e1d2b16b4b5e3a480777f834c2cbeb6d00a755e4
--- /dev/null
+++ b/paddle/fluid/operators/fused_elemwise_activation_op.cu
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+ fused_elemwise_activation,
+ ops::FusedElemwiseActivationKernel,
+ ops::FusedElemwiseActivationKernel);
+
+REGISTER_OP_CUDA_KERNEL(
+ fused_elemwise_activation_grad,
+ ops::FusedElemwiseActivationGradKernel,
+ ops::FusedElemwiseActivationGradKernel);
diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused_elemwise_activation_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..fe0017b824532b1210d0ae3e51983d63d081f12a
--- /dev/null
+++ b/paddle/fluid/operators/fused_elemwise_activation_op.h
@@ -0,0 +1,425 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/functors.h"
+
+namespace math = paddle::operators::math;
+
+namespace paddle {
+namespace operators {
+
+// CompoundFunctors
+// For example: Z = Binary(X, Unary(Y))
+template
+struct BinaryCompoundFunctor {
+ BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun)
+ : binary_fun_(binary_fun), unary_fun_(unary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y) {
+ return binary_fun_(x, unary_fun_(y));
+ }
+
+ private:
+ BinaryFun binary_fun_;
+ UnaryFun unary_fun_;
+};
+
+// For example: Z = Unary(Binary(X, Y))
+template
+struct UnaryCompoundFunctor {
+ UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun)
+ : unary_fun_(unary_fun), binary_fun_(binary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y) {
+ return unary_fun_(binary_fun_(x, y));
+ }
+
+ private:
+ UnaryFun unary_fun_;
+ BinaryFun binary_fun_;
+};
+
+// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
+// the dx, one is to use the 'out', and the other is not to use it.
+// the former method will save the time of recomputing the
+// 'out', but it must occupy the memory to store the 'out'.
+// While the later method can avoid occupying this memory,
+// but it must recompute the 'out'.
+
+template
+struct BinaryCompoundGradDxFunctor {
+ BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
+ const UnaryFun &unary_fun)
+ : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
+ if (Recomputation) {
+ return dout * d_binary_fun_(x, unary_fun_(y));
+ } else {
+ return dout * d_binary_fun_(x, unary_fun_(y), out);
+ }
+ }
+
+ private:
+ DBinaryFun d_binary_fun_;
+ UnaryFun unary_fun_;
+};
+
+template
+struct BinaryCompoundGradDyFunctor {
+ BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
+ const UnaryFun &unary_fun,
+ const DUnaryFun &d_unary_fun)
+ : d_binary_fun_(d_binary_fun),
+ unary_fun_(unary_fun),
+ d_unary_fun_(d_unary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
+ if (Recomputation) {
+ return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y);
+ } else {
+ return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y);
+ }
+ }
+
+ private:
+ DBinaryFun d_binary_fun_;
+ UnaryFun unary_fun_;
+ DUnaryFun d_unary_fun_;
+};
+
+template
+struct UnaryCompoundGradDxFunctor {
+ UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
+ const BinaryFun &binary_fun,
+ const DBinaryFun &d_binary_fun)
+ : d_unary_fun_(d_unary_fun),
+ binary_fun_(binary_fun),
+ d_binary_fun_(d_binary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
+ T base;
+ if (Recomputation) {
+ base = dout * d_unary_fun_(binary_fun_(x, y));
+ } else {
+ base = dout * d_unary_fun_(binary_fun_(x, y), out);
+ }
+ return base * d_binary_fun_(x, y);
+ }
+
+ private:
+ DUnaryFun d_unary_fun_;
+ BinaryFun binary_fun_;
+ DBinaryFun d_binary_fun_;
+};
+
+template
+struct UnaryCompoundGradDyFunctor {
+ UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
+ const BinaryFun &binary_fun,
+ const DBinaryFun &d_binary_fun)
+ : d_unary_fun_(d_unary_fun),
+ binary_fun_(binary_fun),
+ d_binary_fun_(d_binary_fun) {}
+
+ inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
+ T base;
+ if (Recomputation) {
+ base = dout * d_unary_fun_(binary_fun_(x, y));
+ } else {
+ base = dout * d_unary_fun_(binary_fun_(x, y), out);
+ }
+ return base * d_binary_fun_(y, x);
+ }
+
+ private:
+ DUnaryFun d_unary_fun_;
+ BinaryFun binary_fun_;
+ DBinaryFun d_binary_fun_;
+};
+
+template
+static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx,
+ const BinaryFunctor &binary_functor,
+ const UnaryFunctor &unary_functor,
+ const framework::Tensor *in_x,
+ const framework::Tensor *in_y,
+ framework::Tensor *output) {
+ int axis = ctx.Attr("axis");
+ using BinaryCompoundFunctor =
+ BinaryCompoundFunctor;
+
+ ElementwiseComputeEx(
+ ctx, in_x, in_y, axis,
+ BinaryCompoundFunctor(binary_functor, unary_functor), output);
+}
+
+template
+static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx,
+ const UnaryFunctor &unary_functor,
+ const BinaryFunctor &binary_functor,
+ const framework::Tensor *in_x,
+ const framework::Tensor *in_y,
+ framework::Tensor *output) {
+ int axis = ctx.Attr("axis");
+
+ using UnaryCompoundFunctor =
+ UnaryCompoundFunctor;
+
+ ElementwiseComputeEx(
+ ctx, in_x, in_y, axis,
+ UnaryCompoundFunctor(unary_functor, binary_functor), output);
+}
+
+template
+static void RunBinaryCompoundGradFunctors(
+ const framework::ExecutionContext &ctx,
+ const BinaryGradFunctor &binary_grad_functor,
+ const UnaryFunctor &unary_functor,
+ const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
+ const framework::Tensor *in_y, const framework::Tensor *in_out,
+ const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
+ framework::Tensor *y_grad) {
+ int axis = ctx.Attr("axis");
+
+ using BinaryCompoundDxFunctor =
+ BinaryCompoundGradDxFunctor;
+ using BinaryCompoundDyFunctor =
+ BinaryCompoundGradDyFunctor;
+
+ ElemwiseGradCompute(
+ ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad,
+ BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
+ BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
+ unary_grad_functor));
+}
+
+template
+static void RunUnaryCompoundGradFunctors(
+ const framework::ExecutionContext &ctx,
+ const UnaryGradFunctor &unary_grad_functor,
+ const BinaryFunctor &binary_functor,
+ const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
+ const framework::Tensor *in_y, const framework::Tensor *in_out,
+ const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
+ framework::Tensor *y_grad) {
+ int axis = ctx.Attr("axis");
+
+ using UnaryCompoundDxFunctor =
+ UnaryCompoundGradDxFunctor;
+ using UnaryCompoundDyFunctor =
+ UnaryCompoundGradDyFunctor;
+
+ ElemwiseGradCompute(
+ ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad,
+ UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
+ binary_grad_functor),
+ UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
+ binary_grad_functor));
+}
+
+template
+static void RunFunctors(const framework::ExecutionContext &ctx,
+ const framework::Tensor *in_x,
+ const framework::Tensor *in_y,
+ framework::Tensor *output) {
+ auto &functors = ctx.Attr>("functor_list");
+ auto funcs_str = functors[0] + "," + functors[1];
+ // TODO(zcd): The following code can be refined.
+ if (funcs_str == "elementwise_add,scale") {
+ // Z = Binary(X, Unary(Y))
+ T scale = static_cast(ctx.Attr("scale"));
+ RunBinaryCompoundFunctor,
+ math::ScaleFunctor>(
+ ctx, math::AddFunctor(), math::ScaleFunctor(scale), in_x, in_y,
+ output);
+ } else if (funcs_str == "scale,elementwise_add") {
+ // Z = Unary(Binary(X, Y))
+ T scale = static_cast(ctx.Attr("scale"));
+ RunUnaryCompoundFunctors,
+ math::AddFunctor>(
+ ctx, math::ScaleFunctor(scale), math::AddFunctor(), in_x, in_y,
+ output);
+ } else if (funcs_str == "elementwise_add,relu") {
+ RunBinaryCompoundFunctor,
+ math::ReluFunctor>(
+ ctx, math::AddFunctor(), math::ReluFunctor(), in_x, in_y, output);
+ } else if (funcs_str == "relu,elementwise_add") {
+ RunUnaryCompoundFunctors,
+ math::AddFunctor>(
+ ctx, math::ReluFunctor(), math::AddFunctor(), in_x, in_y, output);
+ } else {
+ PADDLE_THROW("%s has not been implemented.", funcs_str);
+ }
+}
+
+template
+static void RunGradFunctors(const framework::ExecutionContext &ctx,
+ const framework::Tensor *in_x,
+ const framework::Tensor *in_y,
+ const framework::Tensor *in_out,
+ const framework::Tensor *in_out_grad,
+ framework::Tensor *x_grad,
+ framework::Tensor *y_grad) {
+ auto &functors = ctx.Attr>("functor_list");
+ auto funcs_str = functors[0] + "," + functors[1];
+
+ bool recomputation = ctx.Attr("recomputation");
+
+ // TODO(zcd): The following code can be refined. for example, use registion
+ if (funcs_str == "elementwise_add_grad,scale_grad") {
+ // The backward of Z = Binary(X, Unary(Y))
+ T scale = static_cast(ctx.Attr("scale"));
+ if (recomputation) {
+ RunBinaryCompoundGradFunctors,
+ math::ScaleFunctor,
+ math::ScaleGradFunctor, true>(
+ ctx, math::AddGradFunctor(), math::ScaleFunctor(scale),
+ math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad,
+ x_grad, y_grad);
+ } else {
+ RunBinaryCompoundGradFunctors,
+ math::ScaleFunctor,
+ math::ScaleGradFunctor, false>(
+ ctx, math::AddGradFunctor(), math::ScaleFunctor(scale),
+ math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad,
+ x_grad, y_grad);
+ }
+ } else if (funcs_str == "scale_grad,elementwise_add_grad") {
+ // The backward of Z = Unary(Binary(X, Y))
+ T scale = static_cast(ctx.Attr("scale"));
+ if (recomputation) {
+ RunUnaryCompoundGradFunctors,
+ math::AddFunctor, math::AddGradFunctor,
+ true>(ctx, math::ScaleGradFunctor(scale),
+ math::AddFunctor(),
+ math::AddGradFunctor(), in_x, in_y,
+ in_out, in_out_grad, x_grad, y_grad);
+ } else {
+ RunUnaryCompoundGradFunctors,
+ math::AddFunctor, math::AddGradFunctor,
+ false>(ctx, math::ScaleGradFunctor(scale),
+ math::AddFunctor(),
+ math::AddGradFunctor(), in_x, in_y,
+ in_out, in_out_grad, x_grad, y_grad);
+ }
+ } else if (funcs_str == "elementwise_add_grad,relu_grad") {
+ if (recomputation) {
+ RunBinaryCompoundGradFunctors,
+ math::ReluFunctor,
+ math::ReluGradFunctor, true>(
+ ctx, math::AddGradFunctor(), math::ReluFunctor(),
+ math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad,
+ y_grad);
+ } else {
+ RunBinaryCompoundGradFunctors,
+ math::ReluFunctor,
+ math::ReluGradFunctor, false>(
+ ctx, math::AddGradFunctor(), math::ReluFunctor(),
+ math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad,
+ y_grad);
+ }
+ } else if (funcs_str == "relu_grad,elementwise_add_grad") {
+ if (recomputation) {
+ RunUnaryCompoundGradFunctors,
+ math::AddFunctor, math::AddGradFunctor,
+ true>(ctx, math::ReluGradFunctor(),
+ math::AddFunctor(),
+ math::AddGradFunctor(), in_x, in_y,
+ in_out, in_out_grad, x_grad, y_grad);
+ } else {
+ RunUnaryCompoundGradFunctors,
+ math::AddFunctor, math::AddGradFunctor,
+ false>(ctx, math::ReluGradFunctor(),
+ math::AddFunctor(),
+ math::AddGradFunctor(), in_x, in_y,
+ in_out, in_out_grad, x_grad, y_grad);
+ }
+ } else {
+ PADDLE_THROW("%s has not been implemented.", funcs_str);
+ }
+}
+
+template
+class FusedElemwiseActivationKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ auto &in_x = detail::Ref(ctx.Input("X"),
+ "Cannot get input tensor %s, variable name = %s",
+ "X", ctx.op().Input("X"));
+ auto &in_y = detail::Ref(ctx.Input("Y"),
+ "Cannot get input tensor %s, variable name = %s",
+ "Y", ctx.op().Input("Y"));
+ auto &output = detail::Ref(ctx.Output("Out"),
+ "Cannot get input tensor %s, variable name = %s",
+ "Out", ctx.op().Output("Out"));
+
+ RunFunctors(ctx, &in_x, &in_y, &output);
+ }
+};
+
+template
+class FusedElemwiseActivationGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ auto &in_x = detail::Ref(ctx.Input("X"),
+ "Cannot get input tensor %s, variable name = %s",
+ "X", ctx.op().Input("X"));
+ auto &in_y = detail::Ref(ctx.Input("Y"),
+ "Cannot get input tensor %s, variable name = %s",
+ "Y", ctx.op().Input("Y"));
+ auto &in_out = detail::Ref(ctx.Input("Out"),
+ "Cannot get input tensor %s, variable name = %s",
+ "Out", ctx.op().Input("Out"));
+ auto &in_out_grad =
+ detail::Ref(ctx.Input(framework::GradVarName("Out")),
+ "Cannot get input tensor %s, variable name = %s",
+ framework::GradVarName("Out"),
+ ctx.op().Input(framework::GradVarName("Out")));
+
+ framework::Tensor *x_grad =
+ ctx.Output(framework::GradVarName("X"));
+ framework::Tensor *y_grad =
+ ctx.Output(framework::GradVarName("Y"));
+
+ RunGradFunctors(ctx, &in_x, &in_y, &in_out, &in_out_grad,
+ x_grad, y_grad);
+ }
+};
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc
index ac35cf0b89bfaa0c0f8e64445f18a3bbd478e70a..27e26cb1b5c1e831f05dac299489628b92eaa58c 100644
--- a/paddle/fluid/operators/load_op.cc
+++ b/paddle/fluid/operators/load_op.cc
@@ -31,9 +31,6 @@ class LoadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
- auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
- platform::RecordEvent record_event(Type(), dev_ctx);
-
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr("file_path");
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 3e8f3ec5c5cd683343bcbdfc2388bd37c25e00f9..d77b095c5d783a2a9fab87eb8b458117a6a3d225 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
+ int ids_rank = ids_dims.size();
- PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
- PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+ PADDLE_ENFORCE_EQ(table_dims.size(), 2);
+ PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
+ "The last dimension of the 'Ids' tensor must be 1.");
- ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
+ auto output_dims =
+ framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
+ output_dims.push_back(table_dims[1]);
+ ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
@@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Ids",
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
- "Ids must be a column vector with rank = 2. "
- "The 2nd dimension size must be 1.");
+ "The last dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr("is_sparse",
"(boolean, default false) "
diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu
index 27483372b93a850d313445386c7973838c4a0710..74823dab09cac358f647c074ac2f2ee2fed17e55 100644
--- a/paddle/fluid/operators/lookup_table_op.cu
+++ b/paddle/fluid/operators/lookup_table_op.cu
@@ -118,28 +118,31 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
auto *d_table = context.Output(framework::GradVarName("W"));
auto *ids_data = ids->data();
- auto ids_dim = ids->dims();
+ int64_t ids_num = ids->numel();
auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector new_rows;
- new_rows.resize(ids_dim[0]);
+ new_rows.resize(ids_num);
auto gpu_place = boost::get(context.GetPlace());
// TODO(yuyang18): Strange code here.
memory::Copy(platform::CPUPlace(),
new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
- ids_data, ids_dim[0] * sizeof(int64_t), stream);
+ ids_data, ids_num * sizeof(int64_t), stream);
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value();
- d_table_value->Resize({ids_dim[0], table->dims()[1]});
+ d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->mutable_data(context.GetPlace());
auto *d_table_data = d_table_value->data();
auto *d_output_data = d_output->data();
- PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
+ auto d_output_dims = d_output->dims();
+ PADDLE_ENFORCE_EQ(
+ d_table_value->dims(),
+ framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel() * sizeof(T), stream);
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index c9f074ca0e8dafb374dc9368165df5af5053a6b8..f5c10ced8305b64c6386c5051804f8c9a8f71802 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -109,17 +109,17 @@ class LookupTableGradKernel : public framework::OpKernel {
auto *d_table = context.Output(framework::GradVarName("W"));
auto *ids_data = ids->data();
- auto ids_dim = ids->dims();
+ int64_t ids_num = ids->numel();
framework::Vector new_rows;
- new_rows.reserve(ids_dim[0]);
- for (int64_t i = 0; i < ids_dim[0]; i++) {
+ new_rows.reserve(ids_num);
+ for (int64_t i = 0; i < ids_num; i++) {
new_rows.push_back(ids_data[i]);
}
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value();
- d_table_value->Resize({ids_dim[0], table_dim[1]});
+ d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->mutable_data(context.GetPlace());
d_table->set_height(table_dim[0]);
@@ -127,7 +127,10 @@ class LookupTableGradKernel : public framework::OpKernel {
auto *d_output_data = d_output->data();
auto *d_table_data = d_table_value->data();
- PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
+ auto d_output_dims = d_output->dims();
+ PADDLE_ENFORCE_EQ(
+ d_table_value->dims(),
+ framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto *ids = context.Input("Ids");
@@ -135,10 +138,9 @@ class LookupTableGradKernel : public framework::OpKernel {
auto *d_table = context.Output(framework::GradVarName("W"));
auto *ids_data = ids->data();
- auto ids_dim = ids->dims();
int N = table_dim[0];
- int D = d_output->dims()[1];
+ int D = table_dim[1];
auto *d_output_data = d_output->data();
auto *d_table_data = d_table->mutable_data