提交 097f533b 编写于 作者: Y Yi Wang

Resolve conflict

...@@ -49,11 +49,12 @@ if(NOT WITH_GOLANG) ...@@ -49,11 +49,12 @@ if(NOT WITH_GOLANG)
endif(NOT WITH_GOLANG) endif(NOT WITH_GOLANG)
if(NOT WITH_GPU) if(NOT WITH_GPU)
add_definitions(-DPADDLE_ONLY_CPU)
add_definitions(-DHPPL_STUB_FUNC) add_definitions(-DHPPL_STUB_FUNC)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu) list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
else() else()
add_definitions(-DPADDLE_WITH_CUDA)
FIND_PACKAGE(CUDA REQUIRED) FIND_PACKAGE(CUDA REQUIRED)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 7) if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
......
...@@ -15,9 +15,9 @@ Please be aware that these Python classes need to maintain some construction-tim ...@@ -15,9 +15,9 @@ Please be aware that these Python classes need to maintain some construction-tim
### Program ### Program
A `ProgramDesc` describes a [DL program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md), which is composed of an array of `BlockDesc`s. A `BlockDesc` refers to its parent block by its index in the array. For example, operators in the step block of an RNN operator needs to be able to access variables in its ancessor blocks. A `ProgramDesc` describes a [DL program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md), which is composed of an array of `BlockDesc`s. The `BlockDesc`s in a `ProgramDesc` can have a tree-like hierarchical structure. However, the `ProgramDesc` onlys stores a flattened array of `BlockDesc`s. A `BlockDesc` refers to its parent block by its index in the array. For example, operators in the step block of an RNN operator need to be able to access variables in its ancestor blocks.
Whenever we create a block, we need set its parent block to the current block, so the Python class `Program` needs to maintain a data member `current_block`. Whenever we create a block, we need to set its parent block to the current block, hence the Python class `Program` needs to maintain a data member `current_block`.
```python ```python
class Program(objects): class Program(objects):
...@@ -81,13 +81,13 @@ class Block(objects): ...@@ -81,13 +81,13 @@ class Block(objects):
self.ops.prepend(Operator(self, ...)) self.ops.prepend(Operator(self, ...))
``` ```
`create_parameter` is necessary because parameters are global variables, those defined in the global block, but can be created in some sub-blocks, e.g., an FC layer in the step block of an RNN operator. `create_parameter` is necessary because parameters are global variables, defined in the global block, but can be created in some sub-blocks. For example, an FC layer in the step block of an RNN operator.
`prepand_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block. `prepend_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block.
### Operator ### Operator
The `Operator` class fills in the `OpDesc` message and calls the C++ function `InferShape` to infer output shape from input shape. The `Operator` class fills in the `OpDesc` message and calls the C++ function `InferShape` to infer the output shapes from the input shapes.
```python ```python
class Operator(object): class Operator(object):
...@@ -105,7 +105,7 @@ class Operator(object): ...@@ -105,7 +105,7 @@ class Operator(object):
return self.proto.type() return self.proto.type()
``` ```
`Operator` creates the `OpDesc` message in C++ space, so could it call the `InferShape` function, which is in C++. `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
### Variable ### Variable
...@@ -128,7 +128,7 @@ class Variable(object): ...@@ -128,7 +128,7 @@ class Variable(object):
self.writer = None self.writer = None
``` ```
Please be aware of `self.writer`, that tracks operator who creates the variable. It possible that there are more than one operators who write a variable, but in Python space, each writes to a variable is represented by a Variable class. This is guaranteed by the fact that **`core.NewVarDesc` must NOT create a new `VarDesc` message if its name already exists in the specified block**. Please be aware of `self.writer`, that tracks operator who creates the variable. It possible that there are more than one operators who write a variable, but in Python space, each write to a variable is represented by a Variable class. This is guaranteed by the fact that **`core.NewVarDesc` must NOT create a new `VarDesc` message if its name already exists in the specified block**.
### Parameter ### Parameter
...@@ -155,7 +155,7 @@ class Parameter(Variable): ...@@ -155,7 +155,7 @@ class Parameter(Variable):
initialize_op_attrs) initialize_op_attrs)
``` ```
When users create a parameter, s/he can call When users create a parameter, they can call
```python ```python
program.create_parameter( program.create_parameter(
......
# Design Doc: Session
## Abstract
The *session* object encapsulates the environment in which the
computation graph is executed.
We will have the *local* session and *remote* session, they offer the
same [interface](#interface). The local session encapsulates the local
runtime environment and the remote session encapsulates the cluster
runtime environment.
The local runtime environment contains:
1. computation devices (i.e., CPU, GPU) handles, and
1. the [scope](../scope.md) which holds all variables.
The remote runtime environment contains:
1. computation devices (i.e., CPU and GPU on node 0, 1) in a cluster,
and
1. the distributed [scope](../scope.md) in a cluster which holds all
variables.
The user can create a remote session on Paddle Cloud and evaluate the
computation graph with it. In this way, the user can control the
remote computation resource in a cluster from his local computer.
## Background
The current design has an implicit global session in which
`paddle.eval()` is executed. The pain point is:
Since the user is not able to explicitly switch between runtime
environments, the user cannot run a topology in two independent
environments.
For example, in reinforcement learning, the user may want to have a
stale model for inference and a fresh model for training, and only
replace the stale model with the fresh model periodically.
Furthermore, we have no concept that encapsulates a remote environment
that executes a computation graph.
We need the session object to address above issues.
## Session
A session is an object that owns the runtime environment. All
computations are executed through `session.eval()`.
### Interface
```python
eval(
targets,
feed_dict=None,
)
```
Evaluates the target Operations or Variables in `targets`.
- *targets*: the evaluation targets. Can be a single Operation or
Variable, or a list with the Operations or Variables as
elements. The value returned by `eval()` has the same shape as the
`target` argument.
The PaddlePaddle program is represented by
the [ProgramDesc](../design/program.md), `eval()` will infer the
ProgramDesc from the given targets and run the PaddlePaddle
program. Please
see
[this graph](./distributed_architecture.md#local-training-architecture) for
the detailed illustration for the local session
and
[this graph](./distributed_architecture.md#distributed-training-architecture) for
the detailed illustration for the remote session.
- *feed_dict*: a dictionary that contains the tensors which override
the edges of the computation graph.
feed_dict not only can provide the input data, it can override any
OP's input as well:
```python
a = pd.constant(2.0, name="a")
b = pd.variable(name="b")
c = pd.mul(a,b)
sess.eval(targets=c, feed_dict={"b":3.0}) # returns 6.0
```
```python
close()
```
Closes the session and releases the scope that the session owns.
### Create a Local Session
```python
session(
devices=None
)
```
Creates a new session. One session owns one global scope, so creating
multiple sessions will create different scopes.
- *devices*: a single `string` or a list of `string` of device names,
the corresponding devices will be the computation devices for
`eval()`. If not specified, all available devices (e.g., all GPUs)
will be used. The user doesn't need to specify the CPU device since
it will be always used. Multiple sessions can use the same device.
#### Example
```Python
a = paddle.constant(1.0)
b = paddle.constant(2.0)
c = a + b
sess = paddle.session(devices=["gpu:0", "gpu:1", "fpga:0"])
sess.eval(c)
sess.close()
```
### Create a Remote Session
```python
create_cloud_job(
name,
num_trainer,
mem_per_trainer,
gpu_per_trainer,
cpu_per_trainer,
num_ps,
mem_per_ps,
cpu_per_ps,
)
```
Creates a Paddle Cloud job. Fails if the job name exists.
```python
get_cloud_job(
name
)
```
Gets a Paddle Cloud job.
```python
remote_session(
job
)
```
- *job*: the Paddle Cloud job.
#### Example
```Python
reader = paddle.reader.recordio("/pfs/home/peter/mnist-train-*") # data stored on Paddle Cloud
image = reader.column(0)
label = reader.column(1)
fc1 = paddle.op.fc(image, size=256, act="sigmoid")
fc2 = paddle.op.fc(fc1, size=10, act="softmax")
cost = paddle.op.cross_entropy(fc2, label)
opt = paddle.optimizer.sgd(cost)
job = paddle.create_cloud_job("test", 3, "1G", 1, 1, 2, "1G", 1)
sess = paddle.remote_ession(job)
for i in range(1000):
sess.eval(opt)
sess.close()
```
...@@ -47,7 +47,7 @@ bool isUsingGpu() { return FLAGS_use_gpu; } ...@@ -47,7 +47,7 @@ bool isUsingGpu() { return FLAGS_use_gpu; }
void setUseGpu(bool useGpu) { FLAGS_use_gpu = useGpu; } void setUseGpu(bool useGpu) { FLAGS_use_gpu = useGpu; }
bool isGpuVersion() { bool isGpuVersion() {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
return false; return false;
#else #else
return true; return true;
......
...@@ -46,7 +46,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat, ...@@ -46,7 +46,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE; if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE;
paddle::real* buf = ptr->mat->getRowBuf(rowID); paddle::real* buf = ptr->mat->getRowBuf(rowID);
size_t width = ptr->mat->getWidth(); size_t width = ptr->mat->getWidth();
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
hl_memcpy(buf, rowArray, sizeof(paddle::real) * width); hl_memcpy(buf, rowArray, sizeof(paddle::real) * width);
#else #else
std::copy(rowArray, rowArray + width, buf); std::copy(rowArray, rowArray + width, buf);
......
...@@ -177,7 +177,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -177,7 +177,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1; for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
++output_idx) { ++output_idx) {
auto insert_add_x = dup_outputs[output_idx]; auto insert_add_x = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx]; auto insert_add_y = dup_outputs[output_idx + 1];
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx); auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
// first add op inserted // first add op inserted
if (output_idx == dup_outputs.size() - 2) { if (output_idx == dup_outputs.size() - 2) {
...@@ -188,9 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -188,9 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
} }
insert_position.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(),
OpRegistry::CreateOp( OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}}, {{"Out", {insert_add_out}}}, {})});
{{"Out", {insert_add_out}}}, {})});
} }
} }
...@@ -230,7 +229,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -230,7 +229,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// process recurrent gradient op as a special operator. // process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") { if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or // NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or
// this will result in infinite loop. // this will result in infinite loop.
const auto& rnnop = const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp); *static_cast<const operators::RecurrentOp*>(&forwardOp);
......
...@@ -105,6 +105,7 @@ message LoDTensorDesc { ...@@ -105,6 +105,7 @@ message LoDTensorDesc {
message VarDesc { message VarDesc {
required string name = 1; required string name = 1;
optional LoDTensorDesc lod_tensor = 2; optional LoDTensorDesc lod_tensor = 2;
optional bool persistable = 3 [ default = false ];
} }
message BlockDesc { message BlockDesc {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h> #include <thrust/system/cuda/experimental/pinned_allocator.h>
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
template <typename T> template <typename T>
using Vector = std::vector<T>; using Vector = std::vector<T>;
#else #else
......
# Design Doc: LoD (Level-of-Detail) Tensor # Design Doc: LoD (Level-of-Detail) Tensor
PaddlePaddle's RNN doesn't require that all instances have the same length. To do so, we introduce an extension to Tensor, namely, LoD Tensor. Like other deep learning systems, PaddlePaddle supports training models from sequence data. Also, like other systems, PaddlePaddle represent a mini-batch of sequences as a Tensor. What is different is that PaddlePaddle doesn't require all sequences in a mini-batch to be of the same length. Thus no need for padding zeros.
## Challenge of Variable-length Inputs | | TensorFlow | PaddlePaddle |
|-----------------------|------------|--------------|
| RNN | Support | Support |
| recursive RNN | Support | Support |
| padding zeros | Must | No need |
| blob data type | Tensor | LoDTensor |
People usually represent a mini-batch by a Tensor. For example, a mini-batch of 10 images, each of size 32x32, is a 10x32x32 Tensor. So a transformation, T, of all images can be a matrix multiplication of the 10xOx32-dimensional tensor T and the 10x32x32 Tensor. PaddlePaddle achieves this flexibility by passing through a new data type, *LoD Tensor*, which is a Tensor attached with segmentation index known as *LoD*, between operators. The LoD index doesn't only segment a tensor, but also recursively segments sub-sequences. This document presents the design of LoD and LoDTensor.
Another example is that each mini-batch contains 32 sentences, where each word is a D-dimensional one-hot vector. If all sentences have the same length L, we can represent this mini-batch by a 32xLxD tensor. However, in most cases, sentences have variable lengths, and we will need an index data structure to record these variable lengths.
## LoD as a Solution ## The Challenge: Variable-length Sequences
### Mini-Batch of variable-length sentences Most deep learning systems represent a mini-batch as a Tensor. For example, a mini-batch of 10 images, each of size 32x32, is a 10x32x32 Tensor. Another example is that each mini-batch contains N sentences, where each word is a D-dimensional one-hot vector. Suppose that all sentences have the same length L, we can represent this mini-batch by a NxLxD tensor.
Let's imagine a mini-batch of 3 variable lengths sentences, containing 3, 1, and 2 words respectively. We can represent it by a (3+1+2)xD tensor plus some index information: Both examples show that the elements of sequences are usually of the same size. In the first example, all images are 32x32, and in the second one, all words are D-dimensional vectors. It doesn't make sense to allow variable-sized images, as that would require transformations like convolution to handle variable-sized Tensors.
The real challenge is that in most cases, sentences have variable lengths, and we will need an index data structure to segment the tensor into sequences. Also, sequences might consist of sub-sequences.
## A Solution: The LoD Index
To understand our solution, it is best to look at some examples.
### A Mini-Batch of Sentences
Let's imagine a mini-batch of 3 variable lengths sentences composed of 3, 1, and 2 words, respectively. We can represent the mini-batch by a (3+1+2)xD tensor plus some index information:
``` ```
3
3 1 2 3 1 2
||| | || ||| | ||
``` ```
Each `|` represents a D-dimensional word vectors. The number 3 on top indicate 3 sentences, and numbers 3, 1, and 2 on the second level represent the number of words in each sentence. where each `|` represents a D-dimensional word vector. The numbers, 3, 1, and 2, form a 1-level LoD.
### Recursive Sequences
Let check another example of a 2-level LoD Tensor. Consider a mini-batch of three articles with 3, 1, and 2 sentences, and each sentence consists of a variable number of words:
```
3 1 2
3 2 4 1 2 3
||| || |||| | || |||
```
### Mini-Batch of variable-length videos ### A Mini-Batch of Videos
This approach generalizes to the case where elements are not words, but higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. If a mini-batch contains 3 videos of 3, 1, and 2 frames respectively. The underlying tensor is of size (3+1+2)x640x480. The index information illustrates as: LoD tensors generalize to the case where elements are higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. Here is a mini-batch of 3 videos with 3, 1, and 2 frames, respectively.
``` ```
3
3 1 2 3 1 2
口口口 口 口口 口口口 口 口口
``` ```
where each `口` represents an image. The underlying tensor is of size (3+1+2)x640x480, and each `口` represents a 640x480 image.
### Mini-Batch of fixed-size images ### A Mini-Batch of Images
Let's get back to a typical example, image classification, where each mini-batch has M fixed-sized images. The LoD Tensor representation is In traditional cases like a mini-batch with N fixed-sized images, the LoD Tensor representation is as
``` ```
M
1 1 1 1 1 1 1 1 1 1
口口口口 ... 口 口口口口 ... 口
``` ```
The many 1's on the second level seem duplicated. For this particular case of 2 levels and the second level always have length 1, we can ignore the LoD index. In this case, we don't lose any information by ignoring the many 1's in the index and simply considering this LoD Tensor as a usual Tensor:
### Design and summarization
In summary, as long as that the essential elements (words or images) have the same size, we can represent mini-batches by a LoD Tensor: ```
口口口口 ... 口
```
- The underlying tensor has size LxD1xD2x..., where D1xD2... is the size of the essential elements, and ### Model Parameters
- The first dimension size L has an additonal property -- a LoD index as a nested vector:
```c++ A model parameter is just a usual Tensor, which, just like the above example, is a **0-level LoD Tensor**.
typedef std::vector<std::<vector>> LoD;
```
- The LoD index is not necessary when there are only two levels and all elements of the second level have length 1.
## Slicing of LoD Tensor ## The LoD Tensor
Consider that we have a network with three levels of RNN: the top level one handles articles, the second level one handles sentences, and the basic level one handles words. This network requires that mini-batches represented by 3 level LoD Tensor, for example, Let us revisit above example of the 2-level LoD Tensor
``` ```
3
3 1 2 3 1 2
3 2 4 1 2 3 3 2 4 1 2 3
||| || |||| | || ||| ||| || |||| | || |||
``` ```
To allow each level of RNN to handle its input, we define **the slicing of a LoD Tensor is defined as getting the j-th sequence on level i, or the <i,j>-slice** It is indeed a tree, where leaves are elementary sequences identified by **branches**.
For example, the third sentence in above example is identified by branch <0,2>, where 0 indicates the first article with length 3, and 2 indicates the third sentence in this article with length 4.
### The LoD Index
For example, the <2,1>-slice of above slice is We can save the LoD index in the above example
``` ```
2 3 1 2
|| 3 2 4 1 2 3
``` ```
and the <1,2>-slice of above example is in a not-full 2D matrix:
```c++
typedef std::vector<std::vector<int> > LoD;
``` ```
2
2 3
|| |||
```
Let's go on slicing this slice. Its <1,1>-slice is where
- `LoD.size()` is the number of levels, or the maximum length of branches,
- `LoD[i][j]` is the length of the j-th segment at the i-th level.
## The Offset Representation
To quickly access elementary sequences, we adopt an offset representation -- instead of saving the lengths, we save the beginning and ending elements of sequences.
In the above example, we accumulate the length of elementary sequences:
``` ```
1 3 2 4 1 2 3
1
|
``` ```
### The Slicing Algorithm into offsets
The algorithm, with over-simplified data structure, is defined as ```
0 3 5 9 10 12 15
= = = = = =
3 2+3 4+5 1+9 2+10 3+12
```
```c++ so we know that the first sentence is from word 0 to word 3, and the second sentence from work 3 to word 5.
typedef std::vector<std::vector<int>> LoD;
struct LoDTensor { Similarly, the lengths in the top level LoD
LoD lod_;
float* tensor_;
};
LoDTensor Slice(const LoDTensor& lodt, int level, int sequence); ```
3 1 2
``` ```
Let us revisit the example above are transformed into offsets of elements/words as follows:
``` ```
3 0 9 10 15
3 1 2 = = =
3 2 4 1 2 3 3+2+4 1+9 2+3+10
||| || |||| | || |||
``` ```
Suppose that we want to retrieve the <1,2>-slice so we can tell that the first article is from word 0 to word 9, and the second article is from word 9 to word 10.
The complete offset representation is as follows:
``` ```
2 0 9 10 15
2 3 0 3 5 9 10 12 15
|| ||| ||| || |||| | || |||
``` ```
we will need to find out the starting position of this slice by summing over all leaf nodes in `LoD` to the left of the slice, i.e., 3 + 2 + 4 + 1 = 10. ## Slicing of LoD Tensors
When we use the above 2-level LoD Tensor as the input to a nested-RNN, we need to retrieve certain sequences. Here we define the sequence identified by branch <i,j,...> as the **<i,j,...>-slice**.
To avoid the traversal of the LoD tree at slicing time, we can do it at the construction time -- instead of saving the lengths of the next level in the LoD tree, we can save the starting offset of the next level. For example, above LoD Tensor can be transformed into For example, the <2>-slice of above example is
``` ```
0 10 15
0 9 10 10 12 15
0 3 5 9 10 12 || |||
||| || |||| | || |||
``` ```
We don't really need the 0 on top, so the LoD Tensor could be and the <2,0>-slice of above slice is
``` ```
0 9 10 10 12
0 3 5 9 10 12 ||
||| || |||| | || |||
``` ```
...@@ -48,4 +48,4 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -48,4 +48,4 @@ TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
\ No newline at end of file
...@@ -218,7 +218,7 @@ class OpKernelRegistrar : public Registrar { ...@@ -218,7 +218,7 @@ class OpKernelRegistrar : public Registrar {
// TODO(fengjiayi): The following macros // TODO(fengjiayi): The following macros
// seems ugly, do we have better method? // seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else #else
#define USE_OP_KERNEL(op_type) \ #define USE_OP_KERNEL(op_type) \
......
...@@ -183,4 +183,4 @@ class CosineOpComplete : public paddle::framework::CosineOp { ...@@ -183,4 +183,4 @@ class CosineOpComplete : public paddle::framework::CosineOp {
TEST(OperatorRegistrar, Test) { TEST(OperatorRegistrar, Test) {
using namespace paddle::framework; using namespace paddle::framework;
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos"); OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
} }
\ No newline at end of file
...@@ -25,7 +25,7 @@ Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< ...@@ -25,7 +25,7 @@ Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
return *device_context_.GetEigenDevice<platform::CPUPlace>(); return *device_context_.GetEigenDevice<platform::CPUPlace>();
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
......
...@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) { ...@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
// collect indice need to copy to the batch // collect indice need to copy to the batch
std::vector<size_t> indice; std::vector<size_t> indice;
for (size_t seq_id = 0; seq_id < meta.size(); seq_id++) { for (const auto& seq : meta) {
const auto& seq_meta = meta[seq_id]; size_t id = seq.begin + index;
if (index >= seq_meta.end) break; if (id >= seq.end) break;
indice.push_back(seq_meta.begin + index); indice.push_back(id);
} }
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index); PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
// copy the indice of records in LoDTensor // copy the indice of records in LoDTensor
...@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) { ...@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
result.Resize(make_ddim(record_dims_vec)); result.Resize(make_ddim(record_dims_vec));
result.mutable_data<value_type>(platform::CPUPlace()); result.mutable_data<value_type>(platform::CPUPlace());
for (size_t i = 0; i < indice.size() - 1; i++) { for (size_t i = 0; i < indice.size(); i++) {
auto index = indice[i]; auto index = indice[i];
auto target = result.Slice<value_type>(i, i + 1); auto target = result.Slice<value_type>(i, i + 1);
auto source_ = source->Slice<value_type>(index, index + 1); auto source_ = source->Slice<value_type>(index, index + 1);
target.CopyFrom<value_type>(source_, platform::CPUPlace()); target.CopyFrom<value_type>(source_, platform::CPUPlace());
} }
return result; return result;
} }
// TODO(supejom) to cache lod if reasonable
LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source, LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod, const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level) { size_t level) {
...@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source, ...@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
} }
result.set_lod(lod); result.set_lod(lod);
return result; return result;
} }
......
...@@ -65,7 +65,7 @@ inline T* Tensor::mutable_data(platform::Place place) { ...@@ -65,7 +65,7 @@ inline T* Tensor::mutable_data(platform::Place place) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size)); boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
} }
#else #else
...@@ -103,7 +103,7 @@ inline void Tensor::CopyFrom(const Tensor& src, ...@@ -103,7 +103,7 @@ inline void Tensor::CopyFrom(const Tensor& src,
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
......
...@@ -74,7 +74,7 @@ TEST(Tensor, MutableData) { ...@@ -74,7 +74,7 @@ TEST(Tensor, MutableData) {
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
float* p1 = nullptr; float* p1 = nullptr;
...@@ -126,7 +126,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -126,7 +126,7 @@ TEST(Tensor, ShareDataWith) {
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
...@@ -163,7 +163,7 @@ TEST(Tensor, Slice) { ...@@ -163,7 +163,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address); EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace()); src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
...@@ -218,7 +218,7 @@ TEST(Tensor, CopyFrom) { ...@@ -218,7 +218,7 @@ TEST(Tensor, CopyFrom) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
Tensor gpu_tensor; Tensor gpu_tensor;
......
...@@ -194,7 +194,7 @@ public: ...@@ -194,7 +194,7 @@ public:
REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward); REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward); REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward); REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward); REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward);
#endif #endif
......
...@@ -395,7 +395,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward, ...@@ -395,7 +395,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC(ContextProjectionBackward, REGISTER_TYPED_FUNC(ContextProjectionBackward,
CPU, CPU,
ContextProjectionBackwardFunc); ContextProjectionBackwardFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(ContextProjectionForward, REGISTER_TYPED_FUNC(ContextProjectionForward,
GPU, GPU,
ContextProjectionForwardFunc); ContextProjectionForwardFunc);
......
...@@ -233,7 +233,7 @@ private: ...@@ -233,7 +233,7 @@ private:
REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc); REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc); REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc); REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc); REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif #endif
......
...@@ -169,7 +169,7 @@ private: ...@@ -169,7 +169,7 @@ private:
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc); REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc); REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc); REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc); REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif #endif
......
...@@ -336,7 +336,7 @@ private: ...@@ -336,7 +336,7 @@ private:
REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc); REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc); REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
#endif #endif
......
...@@ -292,7 +292,7 @@ REGISTER_TYPED_FUNC(DepthwiseConvGradInput, ...@@ -292,7 +292,7 @@ REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, REGISTER_TYPED_FUNC(DepthwiseConvGradFilter,
CPU, CPU,
DepthwiseConvGradFilterFunction); DepthwiseConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction); REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradInput, REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
GPU, GPU,
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(DepthwiseConv, Forward) { TEST(DepthwiseConv, Forward) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>( DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "DepthwiseConv-GPU", forward); "GemmConv-CPU", "DepthwiseConv-GPU", forward);
......
...@@ -340,7 +340,7 @@ public: ...@@ -340,7 +340,7 @@ public:
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
......
...@@ -24,7 +24,7 @@ TEST(GemmConv, NaiveConv) { ...@@ -24,7 +24,7 @@ TEST(GemmConv, NaiveConv) {
"NaiveConv-CPU", "GemmConv-CPU", forward); "NaiveConv-CPU", "GemmConv-CPU", forward);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(GemmConv, Forward) { TEST(GemmConv, Forward) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>( Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "GemmConv-GPU", forward); "GemmConv-CPU", "GemmConv-GPU", forward);
......
...@@ -116,7 +116,7 @@ void TestIm2ColFunctor() { ...@@ -116,7 +116,7 @@ void TestIm2ColFunctor() {
TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); } TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); } TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
......
...@@ -341,7 +341,7 @@ private: ...@@ -341,7 +341,7 @@ private:
}; };
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc); REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc); REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif #endif
} // namespace paddle } // namespace paddle
...@@ -207,7 +207,7 @@ private: ...@@ -207,7 +207,7 @@ private:
REGISTER_TYPED_FUNC(Pad, CPU, PadFunc); REGISTER_TYPED_FUNC(Pad, CPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, CPU, PadGradFunc); REGISTER_TYPED_FUNC(PadGrad, CPU, PadGradFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(Pad, GPU, PadFunc); REGISTER_TYPED_FUNC(Pad, GPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, GPU, PadGradFunc); REGISTER_TYPED_FUNC(PadGrad, GPU, PadGradFunc);
#endif #endif
......
...@@ -217,7 +217,7 @@ public: ...@@ -217,7 +217,7 @@ public:
REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc); REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc); REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc); REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc); REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc);
#endif #endif
......
...@@ -132,7 +132,7 @@ public: ...@@ -132,7 +132,7 @@ public:
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc); REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc); REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc); REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc); REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif #endif
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "BatchNormalizationLayer.h" #include "BatchNormalizationLayer.h"
#include "Layer.h" #include "Layer.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include "CudnnBatchNormLayer.h" #include "CudnnBatchNormLayer.h"
#endif #endif
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include "hl_batch_transpose.h" #include "hl_batch_transpose.h"
#endif #endif
#include "BatchNormalizationLayer.h" #include "BatchNormalizationLayer.h"
...@@ -90,7 +90,7 @@ void BatchNormalizationLayer::expandMat(const MatrixPtr& in, MatrixPtr& out) { ...@@ -90,7 +90,7 @@ void BatchNormalizationLayer::expandMat(const MatrixPtr& in, MatrixPtr& out) {
size_t batchSize = in->getHeight(); size_t batchSize = in->getHeight();
CHECK_EQ(out->getHeight(), batchSize * imgPixels_); CHECK_EQ(out->getHeight(), batchSize * imgPixels_);
if (useGpu_) { if (useGpu_) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
LOG(FATAL) << "paddle is compiled only for cpu"; LOG(FATAL) << "paddle is compiled only for cpu";
#else #else
batchTranspose( batchTranspose(
...@@ -127,7 +127,7 @@ void BatchNormalizationLayer::shrinkMat(const MatrixPtr& in, MatrixPtr& out) { ...@@ -127,7 +127,7 @@ void BatchNormalizationLayer::shrinkMat(const MatrixPtr& in, MatrixPtr& out) {
} }
CHECK_EQ(in->getHeight(), static_cast<size_t>(batchSize * imgPixels_)); CHECK_EQ(in->getHeight(), static_cast<size_t>(batchSize * imgPixels_));
if (useGpu_) { if (useGpu_) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
LOG(FATAL) << "paddle is compiled only for cpu"; LOG(FATAL) << "paddle is compiled only for cpu";
#else #else
batchTranspose( batchTranspose(
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include "PoolLayer.h" #include "PoolLayer.h"
#include "PoolProjectionLayer.h" #include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include "CudnnPoolLayer.h" #include "CudnnPoolLayer.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -53,7 +53,7 @@ Layer* PoolLayer::create(const LayerConfig& config) { ...@@ -53,7 +53,7 @@ Layer* PoolLayer::create(const LayerConfig& config) {
const std::string& pool = config.inputs(0).pool_conf().pool_type(); const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection" || pool == "avg-projection") { if (pool == "max-projection" || pool == "avg-projection") {
return new PoolProjectionLayer(config); return new PoolProjectionLayer(config);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
} else if (CudnnPoolLayer::typeCheck(pool)) { } else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config); return new CudnnPoolLayer(config);
#endif #endif
......
...@@ -674,7 +674,7 @@ void testLayerGradKernel(TestConfig testConf, ...@@ -674,7 +674,7 @@ void testLayerGradKernel(TestConfig testConf,
bool useGpu, bool useGpu,
bool useWeight, bool useWeight,
float epsilon) { float epsilon) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) return; if (useGpu) return;
#endif #endif
FLAGS_use_gpu = useGpu; FLAGS_use_gpu = useGpu;
......
...@@ -119,7 +119,7 @@ TEST(Layer, batchNorm) { ...@@ -119,7 +119,7 @@ TEST(Layer, batchNorm) {
CHECK_EQ(static_cast<int>(convLayer->getOutputValue()->getWidth()), 576); CHECK_EQ(static_cast<int>(convLayer->getOutputValue()->getWidth()), 576);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
void batchNormInference(int n, int c, int h, int w) { void batchNormInference(int n, int c, int h, int w) {
MatrixPtr input = std::make_shared<GpuMatrix>(n, c * h * w); MatrixPtr input = std::make_shared<GpuMatrix>(n, c * h * w);
MatrixPtr cudnnOut = std::make_shared<GpuMatrix>(n, c * h * w); MatrixPtr cudnnOut = std::make_shared<GpuMatrix>(n, c * h * w);
......
...@@ -117,7 +117,7 @@ MatrixPtr doOneConvTest(size_t imgSize, ...@@ -117,7 +117,7 @@ MatrixPtr doOneConvTest(size_t imgSize,
} }
TEST(Layer, convParaUnified) { TEST(Layer, convParaUnified) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
MatrixPtr input, resultCpu, resultGpu; MatrixPtr input, resultCpu, resultGpu;
/// TEST1 for conv /// /// TEST1 for conv ///
......
...@@ -150,7 +150,7 @@ TEST(Layer, detectionOutputLayerFwd) { ...@@ -150,7 +150,7 @@ TEST(Layer, detectionOutputLayerFwd) {
useGpu, useGpu,
result2); result2);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
// GPU case 1. // GPU case 1.
useGpu = true; useGpu = true;
inputLoc = Matrix::create(1, 16, false, useGpu); inputLoc = Matrix::create(1, 16, false, useGpu);
......
...@@ -51,7 +51,7 @@ void testEvaluator(TestConfig testConf, ...@@ -51,7 +51,7 @@ void testEvaluator(TestConfig testConf,
string testEvaluatorName, string testEvaluatorName,
size_t batchSize, size_t batchSize,
bool useGpu) { bool useGpu) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) return; if (useGpu) return;
#endif #endif
FLAGS_use_gpu = useGpu; FLAGS_use_gpu = useGpu;
......
...@@ -97,7 +97,7 @@ TEST(Layer, kmaxSeqScoreLayer) { ...@@ -97,7 +97,7 @@ TEST(Layer, kmaxSeqScoreLayer) {
Matrix::create(subSeqStartPosition.back(), 1, false, false); Matrix::create(subSeqStartPosition.back(), 1, false, false);
std::vector<bool> mode = {false}; std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
mode.push_back(true); mode.push_back(true);
#endif #endif
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <cudnn.h> #include <cudnn.h>
#endif #endif
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -258,7 +258,7 @@ void testProjectionConv(size_t groups, bool isDeconv) { ...@@ -258,7 +258,7 @@ void testProjectionConv(size_t groups, bool isDeconv) {
true); true);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(Projection, conv) { TEST(Projection, conv) {
/// test ConvProjection /// test ConvProjection
testProjectionConv(1, false); testProjectionConv(1, false);
...@@ -422,7 +422,7 @@ TEST(Layer, depthwiseConvLayer) { ...@@ -422,7 +422,7 @@ TEST(Layer, depthwiseConvLayer) {
// 'depthwise_conv' is a sepecial case of 'exconv' whose // 'depthwise_conv' is a sepecial case of 'exconv' whose
// groups size equals to the input channels size. // groups size equals to the input channels size.
testDepthwiseConvLayer("exconv", /* useGpu= */ false); testDepthwiseConvLayer("exconv", /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testDepthwiseConvLayer("exconv", /* useGpu= */ true); testDepthwiseConvLayer("exconv", /* useGpu= */ true);
#endif #endif
} }
...@@ -480,7 +480,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -480,7 +480,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, convLayer) { TEST(Layer, convLayer) {
testConvLayer("exconv", /* trans= */ false, /* useGpu= */ false); testConvLayer("exconv", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testConvLayer("exconv", /* trans= */ false, /* useGpu= */ true); testConvLayer("exconv", /* trans= */ false, /* useGpu= */ true);
testConvLayer("cudnn_conv", /* trans= */ false, /* useGpu= */ true); testConvLayer("cudnn_conv", /* trans= */ false, /* useGpu= */ true);
#endif #endif
...@@ -525,7 +525,7 @@ TEST(Layer, convTransLayer) { ...@@ -525,7 +525,7 @@ TEST(Layer, convTransLayer) {
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
testConvTransLayer("exconvt", /* trans= */ false, /* useGpu= */ useGpu); testConvTransLayer("exconvt", /* trans= */ false, /* useGpu= */ useGpu);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testConvTransLayer("cudnn_convt", /* trans= */ false, /* useGpu= */ true); testConvTransLayer("cudnn_convt", /* trans= */ false, /* useGpu= */ true);
#endif #endif
} }
...@@ -638,7 +638,7 @@ TEST(Layer, SelectiveFullyConnectedLayer) { ...@@ -638,7 +638,7 @@ TEST(Layer, SelectiveFullyConnectedLayer) {
/* trans= */ false, /* trans= */ false,
/* useGup= */ false, /* useGup= */ false,
false); false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testLayerGrad(config, testLayerGrad(config,
"selective_fc", "selective_fc",
100, 100,
...@@ -1210,7 +1210,7 @@ void testPoolLayer(const string& poolType, bool trans, bool useGpu) { ...@@ -1210,7 +1210,7 @@ void testPoolLayer(const string& poolType, bool trans, bool useGpu) {
testLayerGrad(config, "pool", 100, trans, useGpu); testLayerGrad(config, "pool", 100, trans, useGpu);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
void testPoolLayer2(const string& poolType, bool trans, bool useGpu) { void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TestConfig config; TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 3200, 0}); config.inputDefs.push_back({INPUT_DATA, "layer_0", 3200, 0});
...@@ -1236,7 +1236,7 @@ TEST(Layer, PoolLayer) { ...@@ -1236,7 +1236,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true); testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ true); testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
...@@ -1309,7 +1309,7 @@ void testPool3DLayer(const string& poolType, bool trans, bool useGpu) { ...@@ -1309,7 +1309,7 @@ void testPool3DLayer(const string& poolType, bool trans, bool useGpu) {
TEST(Layer, Pool3DLayer) { TEST(Layer, Pool3DLayer) {
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ false); testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ false);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ false); testPool3DLayer("max", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ true); testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ true);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ true); testPool3DLayer("max", /* trans= */ false, /* useGpu= */ true);
#endif #endif
...@@ -1695,7 +1695,7 @@ void testBatchNormLayer(const string& type, bool trans, bool useGpu) { ...@@ -1695,7 +1695,7 @@ void testBatchNormLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, BatchNormalizationLayer) { TEST(Layer, BatchNormalizationLayer) {
testBatchNormLayer("batch_norm", false, false); testBatchNormLayer("batch_norm", false, false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testBatchNormLayer("batch_norm", false, true); testBatchNormLayer("batch_norm", false, true);
if (hl_get_cudnn_lib_version() >= int(4000)) { if (hl_get_cudnn_lib_version() >= int(4000)) {
testBatchNormLayer("cudnn_batch_norm", false, true); testBatchNormLayer("cudnn_batch_norm", false, true);
...@@ -1744,7 +1744,7 @@ void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) { ...@@ -1744,7 +1744,7 @@ void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, testBatchNorm3DLayer) { TEST(Layer, testBatchNorm3DLayer) {
testBatchNorm3DLayer("batch_norm", false, false); testBatchNorm3DLayer("batch_norm", false, false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testBatchNorm3DLayer("batch_norm", false, true); testBatchNorm3DLayer("batch_norm", false, true);
if (hl_get_cudnn_lib_version() >= int(4000)) { if (hl_get_cudnn_lib_version() >= int(4000)) {
testBatchNorm3DLayer("cudnn_batch_norm", false, true); testBatchNorm3DLayer("cudnn_batch_norm", false, true);
...@@ -2262,7 +2262,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -2262,7 +2262,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, test3DConvLayer) { TEST(Layer, test3DConvLayer) {
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false); test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true); test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true);
#endif #endif
} }
...@@ -2339,7 +2339,7 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -2339,7 +2339,7 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, test3DDeConvLayer) { TEST(Layer, test3DDeConvLayer) {
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false); test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true); test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true);
#endif #endif
} }
......
...@@ -243,7 +243,7 @@ TEST(Compare, concat_slice) { ...@@ -243,7 +243,7 @@ TEST(Compare, concat_slice) {
compareNetwork(config_file_a, config_file_b); compareNetwork(config_file_a, config_file_b);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(Compare, img_pool) { TEST(Compare, img_pool) {
std::string config_file_a = "./gserver/tests/img_pool_a.conf"; std::string config_file_a = "./gserver/tests/img_pool_a.conf";
std::string config_file_b = "./gserver/tests/img_pool_b.conf"; std::string config_file_b = "./gserver/tests/img_pool_b.conf";
......
...@@ -151,7 +151,7 @@ TEST(Layer, priorBoxLayerFwd) { ...@@ -151,7 +151,7 @@ TEST(Layer, priorBoxLayerFwd) {
useGpu, useGpu,
result); result);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
// reset the input parameters // reset the input parameters
variance[1] = 0.1; variance[1] = 0.1;
variance[3] = 0.2; variance[3] = 0.2;
......
...@@ -485,7 +485,7 @@ TEST(ProtoDataProvider, test) { ...@@ -485,7 +485,7 @@ TEST(ProtoDataProvider, test) {
// Currently in async mode, useGpu is not supported // Currently in async mode, useGpu is not supported
continue; continue;
} }
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
continue; continue;
} }
...@@ -525,7 +525,7 @@ TEST(ProtoDataProvider, constant_slots) { ...@@ -525,7 +525,7 @@ TEST(ProtoDataProvider, constant_slots) {
for (int numConstantSlots : {1, 2}) { for (int numConstantSlots : {1, 2}) {
for (int useGpu : numTwoArray) { for (int useGpu : numTwoArray) {
for (int dataCompression : numTwoArray) { for (int dataCompression : numTwoArray) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
continue; continue;
} }
...@@ -708,7 +708,7 @@ TEST(ProtoSequenceDataProvider, test) { ...@@ -708,7 +708,7 @@ TEST(ProtoSequenceDataProvider, test) {
// Currently in async mode, useGpu is not supported // Currently in async mode, useGpu is not supported
continue; continue;
} }
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
continue; continue;
} }
......
...@@ -37,7 +37,7 @@ TEST(PyDataProvider, py_fill_slots) { ...@@ -37,7 +37,7 @@ TEST(PyDataProvider, py_fill_slots) {
config.clear_files(); config.clear_files();
std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList"; std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList";
config.set_files(dataFile); config.set_files(dataFile);
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
bool useGpu = false; bool useGpu = false;
#else #else
bool useGpu = true; bool useGpu = true;
...@@ -71,7 +71,7 @@ TEST(PyDataProvider, py_fill_nest_slots) { ...@@ -71,7 +71,7 @@ TEST(PyDataProvider, py_fill_nest_slots) {
std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList"; std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList";
config.set_files(dataFile); config.set_files(dataFile);
EXPECT_EQ(config.IsInitialized(), true); EXPECT_EQ(config.IsInitialized(), true);
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
bool useGpu = false; bool useGpu = false;
#else #else
bool useGpu = true; bool useGpu = true;
......
...@@ -321,7 +321,7 @@ TEST(Layer, SelectiveFcLayer_train_dense_mul) { ...@@ -321,7 +321,7 @@ TEST(Layer, SelectiveFcLayer_train_dense_mul) {
"filelist=gserver/tests/SelectiveFcTest/dense_mul_list"; "filelist=gserver/tests/SelectiveFcTest/dense_mul_list";
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
break; break;
} }
...@@ -388,7 +388,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config, ...@@ -388,7 +388,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config,
outMatSelfc->getWidth(), outMatSelfc->getWidth(),
outMatSelfc->getElementCnt())); outMatSelfc->getElementCnt()));
cpuOutMatSelfc->copyFrom(*outMatSelfc, HPPL_STREAM_DEFAULT); cpuOutMatSelfc->copyFrom(*outMatSelfc, HPPL_STREAM_DEFAULT);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
hl_stream_synchronize(HPPL_STREAM_DEFAULT); hl_stream_synchronize(HPPL_STREAM_DEFAULT);
} }
...@@ -418,7 +418,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config, ...@@ -418,7 +418,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config,
MatrixPtr cpuOutMatFc( MatrixPtr cpuOutMatFc(
new CpuMatrix(outMatFc->getHeight(), outMatFc->getWidth())); new CpuMatrix(outMatFc->getHeight(), outMatFc->getWidth()));
cpuOutMatFc->copyFrom(*outMatFc, HPPL_STREAM_DEFAULT); cpuOutMatFc->copyFrom(*outMatFc, HPPL_STREAM_DEFAULT);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
if (useGpu) { if (useGpu) {
hl_stream_synchronize(HPPL_STREAM_DEFAULT); hl_stream_synchronize(HPPL_STREAM_DEFAULT);
} }
...@@ -443,7 +443,7 @@ TEST(Layer, SelectiveFcLayer_train_sparse_mul) { ...@@ -443,7 +443,7 @@ TEST(Layer, SelectiveFcLayer_train_sparse_mul) {
selLayerConfig.set_size(fcLayerWidth); selLayerConfig.set_size(fcLayerWidth);
testSelectiveFcLayerTrainSparseMul(selLayerConfig, false); testSelectiveFcLayerTrainSparseMul(selLayerConfig, false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testSelectiveFcLayerTrainSparseMul(selLayerConfig, true); testSelectiveFcLayerTrainSparseMul(selLayerConfig, true);
#endif #endif
} }
......
...@@ -195,7 +195,7 @@ TEST(Layer, SeqSliceLayer) { ...@@ -195,7 +195,7 @@ TEST(Layer, SeqSliceLayer) {
vector<vector<real>> ends; vector<vector<real>> ends;
std::vector<bool> mode = {false}; std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
mode.push_back(true); mode.push_back(true);
#endif #endif
genSeqInfo(seqStartPos, subSeqStartPos); genSeqInfo(seqStartPos, subSeqStartPos);
......
...@@ -199,7 +199,7 @@ TEST(Layer, WarpCTCLayer) { ...@@ -199,7 +199,7 @@ TEST(Layer, WarpCTCLayer) {
for (auto batchSize : {1, 10, 32}) { for (auto batchSize : {1, 10, 32}) {
for (auto normByTimes : {false, true}) { for (auto normByTimes : {false, true}) {
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) continue; if (useGpu) continue;
#endif #endif
LOG(INFO) << "layerSize=" << layerSize << " batchSize=" << batchSize LOG(INFO) << "layerSize=" << layerSize << " batchSize=" << batchSize
......
...@@ -670,7 +670,7 @@ void GpuMatrix::leftMul(Matrix& a, real scaleAB, real scaleT) { ...@@ -670,7 +670,7 @@ void GpuMatrix::leftMul(Matrix& a, real scaleAB, real scaleT) {
} }
void GpuMatrix::selectRows(Matrix& table, IVector& ids) { void GpuMatrix::selectRows(Matrix& table, IVector& ids) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
CHECK(dynamic_cast<GpuMatrix*>(&table)); CHECK(dynamic_cast<GpuMatrix*>(&table));
CHECK(table.useGpu()); CHECK(table.useGpu());
CHECK(ids.useGpu()); CHECK(ids.useGpu());
...@@ -694,7 +694,7 @@ void GpuMatrix::selectRows(Matrix& table, IVector& ids) { ...@@ -694,7 +694,7 @@ void GpuMatrix::selectRows(Matrix& table, IVector& ids) {
} }
void GpuMatrix::addToRows(Matrix& table, IVector& ids) { void GpuMatrix::addToRows(Matrix& table, IVector& ids) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
CHECK(dynamic_cast<GpuMatrix*>(&table)); CHECK(dynamic_cast<GpuMatrix*>(&table));
CHECK(table.useGpu()); CHECK(table.useGpu());
CHECK(ids.useGpu()); CHECK(ids.useGpu());
...@@ -741,7 +741,7 @@ void GpuMatrix::rowMax(Matrix& max) { ...@@ -741,7 +741,7 @@ void GpuMatrix::rowMax(Matrix& max) {
} }
void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal"; CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal";
size_t numSamples = getHeight(); size_t numSamples = getHeight();
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
......
...@@ -836,7 +836,7 @@ void GpuSparseMatrix::zeroMem() { ...@@ -836,7 +836,7 @@ void GpuSparseMatrix::zeroMem() {
} }
void GpuSparseMatrix::rowMax(IVector& maxIds, Matrix& maxVal) { void GpuSparseMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal"; CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal";
size_t numSamples = getHeight(); size_t numSamples = getHeight();
size_t beam = maxVal.getWidth(); size_t beam = maxVal.getWidth();
......
...@@ -172,7 +172,7 @@ void GpuVectorT<T>::isEqualTo(const VectorT<T>& b, const T& value) { ...@@ -172,7 +172,7 @@ void GpuVectorT<T>::isEqualTo(const VectorT<T>& b, const T& value) {
template <class T> template <class T>
void GpuVectorT<T>::selectFrom(const VectorT<T>& src, const VectorT<int>& ids) { void GpuVectorT<T>::selectFrom(const VectorT<T>& src, const VectorT<int>& ids) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
hl_vector_select_from<T>(this->getData(), hl_vector_select_from<T>(this->getData(),
this->getSize(), this->getSize(),
src.getData(), src.getData(),
...@@ -850,7 +850,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src, ...@@ -850,7 +850,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src,
size_t size) size_t size)
: sync_(nullptr) { : sync_(nullptr) {
CHECK_LE(offset + size, static_cast<size_t>(src.getSize())); CHECK_LE(offset + size, static_cast<size_t>(src.getSize()));
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
SyncedFlag* flag = src.getSync(); SyncedFlag* flag = src.getSync();
if (*flag == DATA_AT_CPU) { if (*flag == DATA_AT_CPU) {
src.copyToGpu(); // will set synchronous data between CPU and GPU src.copyToGpu(); // will set synchronous data between CPU and GPU
...@@ -861,7 +861,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src, ...@@ -861,7 +861,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src,
auto cMemHandle = (src.getVector(false))->getMemoryHandle(); auto cMemHandle = (src.getVector(false))->getMemoryHandle();
cpuVectorT_ = std::make_shared<CpuVectorT<T>>( cpuVectorT_ = std::make_shared<CpuVectorT<T>>(
size, std::dynamic_pointer_cast<CpuMemoryHandle>(cMemHandle), offset); size, std::dynamic_pointer_cast<CpuMemoryHandle>(cMemHandle), offset);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
auto gMemHandle = (src.getVector(true))->getMemoryHandle(); auto gMemHandle = (src.getVector(true))->getMemoryHandle();
gpuVectorT_ = std::make_shared<GpuVectorT<T>>( gpuVectorT_ = std::make_shared<GpuVectorT<T>>(
size, std::dynamic_pointer_cast<GpuMemoryHandle>(gMemHandle), offset); size, std::dynamic_pointer_cast<GpuMemoryHandle>(gMemHandle), offset);
......
...@@ -68,7 +68,7 @@ void testPoolAllocator() { ...@@ -68,7 +68,7 @@ void testPoolAllocator() {
TEST(Allocator, Pool) { TEST(Allocator, Pool) {
testPoolAllocator<CpuAllocator>(); testPoolAllocator<CpuAllocator>();
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testPoolAllocator<GpuAllocator>(); testPoolAllocator<GpuAllocator>();
#endif #endif
} }
...@@ -92,7 +92,7 @@ TEST(MemoryHandle, Cpu) { ...@@ -92,7 +92,7 @@ TEST(MemoryHandle, Cpu) {
EXPECT_EQ(ptr1, ptr2); EXPECT_EQ(ptr1, ptr2);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(MemoryHandle, Gpu) { TEST(MemoryHandle, Gpu) {
int numGpu = hl_get_device_count(); int numGpu = hl_get_device_count();
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
/** /**
* This test file use autotest::AutoCompare and cmpWithoutArg to compares the * This test file use autotest::AutoCompare and cmpWithoutArg to compares the
* implementation of CPU and GPU member function in * implementation of CPU and GPU member function in
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
......
...@@ -94,7 +94,7 @@ void testWrapper(F&& f) { ...@@ -94,7 +94,7 @@ void testWrapper(F&& f) {
} }
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(ExecViaCpu, test1) { TEST(ExecViaCpu, test1) {
testWrapper(f); testWrapper(f);
testWrapper(&f); testWrapper(&f);
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
/** /**
* This test file use autotest::AutoCompare and cmpWithArg to compares the * This test file use autotest::AutoCompare and cmpWithArg to compares the
* implementation of CPU and GPU member function in Matrix.cpp. * implementation of CPU and GPU member function in Matrix.cpp.
......
...@@ -47,7 +47,7 @@ struct MatrixPara { ...@@ -47,7 +47,7 @@ struct MatrixPara {
SparseFormat format; SparseFormat format;
}; };
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
void test_sparse_matrix_mul(MatrixPara paraA, void test_sparse_matrix_mul(MatrixPara paraA,
MatrixPara paraB, MatrixPara paraB,
MatrixPara paraC) { MatrixPara paraC) {
...@@ -452,7 +452,7 @@ TEST(Matrix, SparseMatrixCSRFormatTrimFrom) { ...@@ -452,7 +452,7 @@ TEST(Matrix, SparseMatrixCSRFormatTrimFrom) {
matB->trimFrom(*mat); matB->trimFrom(*mat);
checkSMatrixEqual2(matA, matB); checkSMatrixEqual2(matA, matB);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>( GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>(
height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSR, true); height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSR, true);
matC->trimFrom(*mat); matC->trimFrom(*mat);
...@@ -546,7 +546,7 @@ TEST(Matrix, SparseMatrixCSCFormatTrimFrom) { ...@@ -546,7 +546,7 @@ TEST(Matrix, SparseMatrixCSCFormatTrimFrom) {
matB->trimFrom(*mat); matB->trimFrom(*mat);
checkSMatrixEqual2(matA, matB); checkSMatrixEqual2(matA, matB);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>( GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>(
height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSC, true); height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSC, true);
matC->trimFrom(*mat); matC->trimFrom(*mat);
......
...@@ -270,7 +270,7 @@ TEST(Unary, BaseOp) { ...@@ -270,7 +270,7 @@ TEST(Unary, BaseOp) {
TestUnaryVectorT<CpuIVector, int> testCpuIVector( TestUnaryVectorT<CpuIVector, int> testCpuIVector(
testUnaryBaseOpInt<CpuIVector>); testUnaryBaseOpInt<CpuIVector>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpuMatrix(testUnaryBaseOp<GpuMatrix>); TestUnaryMatrix<GpuMatrix> testGpuMatrix(testUnaryBaseOp<GpuMatrix>);
TestUnaryVectorT<GpuVector, real> testGpuVector(testUnaryBaseOp<GpuVector>); TestUnaryVectorT<GpuVector, real> testGpuVector(testUnaryBaseOp<GpuVector>);
TestUnaryVectorT<GpuIVector, int> testGpuIVector( TestUnaryVectorT<GpuIVector, int> testGpuIVector(
...@@ -317,7 +317,7 @@ void testUnayrMathOp(Tensor& A1, Tensor& A2) { ...@@ -317,7 +317,7 @@ void testUnayrMathOp(Tensor& A1, Tensor& A2) {
TEST(Unary, MathOp) { TEST(Unary, MathOp) {
TestUnaryMatrix<CpuMatrix> testCpu(testUnayrMathOp<CpuMatrix>); TestUnaryMatrix<CpuMatrix> testCpu(testUnayrMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpu(testUnayrMathOp<GpuMatrix>); TestUnaryMatrix<GpuMatrix> testGpu(testUnayrMathOp<GpuMatrix>);
#endif #endif
} }
...@@ -374,7 +374,7 @@ void testUnayrCompareOp(Tensor& A1, Tensor& A2) { ...@@ -374,7 +374,7 @@ void testUnayrCompareOp(Tensor& A1, Tensor& A2) {
TEST(Unary, CompareOp) { TEST(Unary, CompareOp) {
TestUnaryMatrix<CpuMatrix> testCpu(testUnayrCompareOp<CpuMatrix>); TestUnaryMatrix<CpuMatrix> testCpu(testUnayrCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpu(testUnayrCompareOp<GpuMatrix>); TestUnaryMatrix<GpuMatrix> testGpu(testUnayrCompareOp<GpuMatrix>);
#endif #endif
} }
...@@ -536,7 +536,7 @@ void testBinaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B) { ...@@ -536,7 +536,7 @@ void testBinaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, BaseOp) { TEST(Binary, BaseOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryBaseOp<CpuMatrix>); TestBinaryMatrix<CpuMatrix> testCpu(testBinaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryBaseOp<GpuMatrix>); TestBinaryMatrix<GpuMatrix> testGpu(testBinaryBaseOp<GpuMatrix>);
#endif #endif
} }
...@@ -710,7 +710,7 @@ void testBinaryMathOp(Tensor& A1, Tensor& A2, Tensor& B) { ...@@ -710,7 +710,7 @@ void testBinaryMathOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, MathOp) { TEST(Binary, MathOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryMathOp<CpuMatrix>); TestBinaryMatrix<CpuMatrix> testCpu(testBinaryMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryMathOp<GpuMatrix>); TestBinaryMatrix<GpuMatrix> testGpu(testBinaryMathOp<GpuMatrix>);
#endif #endif
} }
...@@ -810,7 +810,7 @@ void testBinaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B) { ...@@ -810,7 +810,7 @@ void testBinaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, CompareOp) { TEST(Binary, CompareOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryCompareOp<CpuMatrix>); TestBinaryMatrix<CpuMatrix> testCpu(testBinaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryCompareOp<GpuMatrix>); TestBinaryMatrix<GpuMatrix> testGpu(testBinaryCompareOp<GpuMatrix>);
#endif #endif
} }
...@@ -955,7 +955,7 @@ void testTernaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) { ...@@ -955,7 +955,7 @@ void testTernaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
TEST(Ternary, BaseOp) { TEST(Ternary, BaseOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryBaseOp<CpuMatrix>); TestTernaryMatrix<CpuMatrix> testCpu(testTernaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryBaseOp<GpuMatrix>); TestTernaryMatrix<GpuMatrix> testGpu(testTernaryBaseOp<GpuMatrix>);
#endif #endif
} }
...@@ -1058,7 +1058,7 @@ void testTernaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) { ...@@ -1058,7 +1058,7 @@ void testTernaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
TEST(Ternary, CompareOp) { TEST(Ternary, CompareOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryCompareOp<CpuMatrix>); TestTernaryMatrix<CpuMatrix> testCpu(testTernaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryCompareOp<GpuMatrix>); TestTernaryMatrix<GpuMatrix> testGpu(testTernaryCompareOp<GpuMatrix>);
#endif #endif
} }
...@@ -1086,7 +1086,7 @@ void testQuaternaryAdd( ...@@ -1086,7 +1086,7 @@ void testQuaternaryAdd(
TEST(Quaternary, BaseOp) { TEST(Quaternary, BaseOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryAdd<CpuMatrix>); TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryAdd<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryAdd<GpuMatrix>); TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryAdd<GpuMatrix>);
#endif #endif
} }
...@@ -1156,7 +1156,7 @@ void testQuaternaryCompareOp( ...@@ -1156,7 +1156,7 @@ void testQuaternaryCompareOp(
TEST(Quaternary, CompareOp) { TEST(Quaternary, CompareOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryCompareOp<CpuMatrix>); TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryCompareOp<GpuMatrix>); TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryCompareOp<GpuMatrix>);
#endif #endif
} }
...@@ -91,7 +91,7 @@ int VectorCheckErr(const VectorPtr& vector1, const VectorPtr& vector2) { ...@@ -91,7 +91,7 @@ int VectorCheckErr(const VectorPtr& vector1, const VectorPtr& vector2) {
typedef std::function<void(size_t size, bool useGpu)> testMatrixFunc; typedef std::function<void(size_t size, bool useGpu)> testMatrixFunc;
void testCase(testMatrixFunc matrixFunc) { void testCase(testMatrixFunc matrixFunc) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
#else #else
for (auto useGpu : {false}) { for (auto useGpu : {false}) {
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
using namespace paddle; // NOLINT using namespace paddle; // NOLINT
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(MatrixBatchTransTest, test_batch_matrix_transpose) { TEST(MatrixBatchTransTest, test_batch_matrix_transpose) {
const int nx = 100; const int nx = 100;
const int ny = 50; const int ny = 50;
......
...@@ -72,7 +72,7 @@ void testLazyAssign(int height, int width) { ...@@ -72,7 +72,7 @@ void testLazyAssign(int height, int width) {
TEST(lazyAssign, CPU) { testMatrixCase(testLazyAssign<CpuMatrix>); } TEST(lazyAssign, CPU) { testMatrixCase(testLazyAssign<CpuMatrix>); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TEST(lazyAssign, GPU) { testMatrixCase(testLazyAssign<GpuMatrix>); } TEST(lazyAssign, GPU) { testMatrixCase(testLazyAssign<GpuMatrix>); }
#endif #endif
...@@ -142,6 +142,6 @@ void testSgdUpdate(int height, int width) { ...@@ -142,6 +142,6 @@ void testSgdUpdate(int height, int width) {
TEST(sgdUpdate, CPU) { testMatrixCase(testSgdUpdate<CpuMatrix>); } TEST(sgdUpdate, CPU) { testMatrixCase(testSgdUpdate<CpuMatrix>); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_GPU
TEST(sgdUpdate, GPU) { testMatrixCase(testSgdUpdate<GpuMatrix>); } TEST(sgdUpdate, GPU) { testMatrixCase(testSgdUpdate<GpuMatrix>); }
#endif #endif
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
/// This unittest checks GpuMatrix/CpuMatrix get same result, so disable when /// This unittest checks GpuMatrix/CpuMatrix get same result, so disable when
/// only cpu version. /// only cpu version.
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
/// This unittest checks GpuSparseMatrix/CpuSparseMatrix get same result, /// This unittest checks GpuSparseMatrix/CpuSparseMatrix get same result,
// so disable when // so disable when
/// only cpu version. /// only cpu version.
......
...@@ -175,7 +175,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) { ...@@ -175,7 +175,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
} }
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
if (system_allocator_->UseGpu()) { if (system_allocator_->UseGpu()) {
if ((total_used_ + total_free_) == 0) { if ((total_used_ + total_free_) == 0) {
// Compute the maximum allocation size for the first allocation. // Compute the maximum allocation size for the first allocation.
......
...@@ -62,7 +62,7 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) { ...@@ -62,7 +62,7 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) {
bool CPUAllocator::UseGpu() const { return false; } bool CPUAllocator::UseGpu() const { return false; }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
void* GPUAllocator::Alloc(size_t& index, size_t size) { void* GPUAllocator::Alloc(size_t& index, size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr // CUDA documentation doesn't explain if cudaMalloc returns nullptr
......
...@@ -40,7 +40,7 @@ class CPUAllocator : public SystemAllocator { ...@@ -40,7 +40,7 @@ class CPUAllocator : public SystemAllocator {
virtual bool UseGpu() const; virtual bool UseGpu() const;
}; };
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
class GPUAllocator : public SystemAllocator { class GPUAllocator : public SystemAllocator {
public: public:
virtual void* Alloc(size_t& index, size_t size); virtual void* Alloc(size_t& index, size_t size);
......
...@@ -56,7 +56,7 @@ TEST(CPUAllocator, LockMem) { ...@@ -56,7 +56,7 @@ TEST(CPUAllocator, LockMem) {
TestAllocator(a, 0); TestAllocator(a, 0);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(GPUAllocator, Alloc) { TEST(GPUAllocator, Alloc) {
paddle::memory::detail::GPUAllocator a; paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048); TestAllocator(a, 2048);
......
...@@ -26,7 +26,7 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst, ...@@ -26,7 +26,7 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
std::memcpy(dst, src, num); std::memcpy(dst, src, num);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <> template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place, void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst, void* dst,
......
...@@ -33,7 +33,7 @@ namespace memory { ...@@ -33,7 +33,7 @@ namespace memory {
template <typename DstPlace, typename SrcPlace> template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
/** /**
* \brief Copy memory from one place to another place. * \brief Copy memory from one place to another place.
......
...@@ -62,7 +62,7 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) { ...@@ -62,7 +62,7 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
return GetCPUBuddyAllocator()->Used(); return GetCPUBuddyAllocator()->Used();
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
using BuddyAllocVec = std::vector<BuddyAllocator*>; using BuddyAllocVec = std::vector<BuddyAllocator*>;
...@@ -77,7 +77,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -77,7 +77,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
// GPU buddy allocator initialization // GPU buddy allocator initialization
std::call_once(gpu_allocator_flag, [&]() { std::call_once(gpu_allocator_flag, [&]() {
int gpu_num = platform::GetDeviceCount(); int gpu_num = platform::GetCUDADeviceCount();
allocators.reserve(gpu_num); allocators.reserve(gpu_num);
for (int gpu = 0; gpu < gpu_num; gpu++) { for (int gpu = 0; gpu < gpu_num; gpu++) {
platform::SetDeviceId(gpu); platform::SetDeviceId(gpu);
......
...@@ -80,7 +80,7 @@ TEST(BuddyAllocator, CPUMultAlloc) { ...@@ -80,7 +80,7 @@ TEST(BuddyAllocator, CPUMultAlloc) {
} }
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
size_t align(size_t size, paddle::platform::GPUPlace place) { size_t align(size_t size, paddle::platform::GPUPlace place) {
size += sizeof(paddle::memory::detail::Metadata); size += sizeof(paddle::memory::detail::Metadata);
......
...@@ -69,6 +69,22 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,6 +69,22 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
template <typename AttrType>
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddComment(
"LeakyRelu activation operator, "
"leaky_relu = max(x, alpha * x)");
AddAttr<AttrType>("alpha", "The small negative slope")
.SetDefault(static_cast<AttrType>(0.02f));
}
};
class TanhOpMaker : public framework::OpProtoAndCheckerMaker { class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
...@@ -240,6 +256,9 @@ REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad, ...@@ -240,6 +256,9 @@ REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad, REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
leaky_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad); soft_relu_grad, ops::ActivationOpGrad);
......
...@@ -309,6 +309,33 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -309,6 +309,33 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(alpha * x);
}
};
template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = alpha * (x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
template <typename T> template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> { struct PowFunctor : public BaseActivationFunctor<T> {
float factor; float factor;
...@@ -379,4 +406,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -379,4 +406,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \ __macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \ __macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor) __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adadelta_op.h"
namespace paddle {
namespace operators {
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredGrad"),
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredGradOut"),
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredUpdateOut"),
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"param and grad input of AdadeltaOp should have same dimension");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
"Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
"Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
};
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad",
"(Tensor) Input expectation of squared gradient");
AddInput("AvgSquaredUpdate",
"(Tensor) Input expectation of squared parameter updates");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("AvgSquaredGradOut",
"(Tensor) Output expectation of squared gradient");
AddOutput("AvgSquaredUpdateOut",
"(Tensor) Output expectation of squared parameter updates");
AddAttr<float>("rho",
"(float, default 0.95) Exponential decay rate "
"for squared gradients.")
.SetDefault(0.95f);
AddAttr<float>("epsilon",
"(float, default 1.0e-6) Constant for "
"numerical stability")
.SetDefault(1.0e-6f);
AddComment(R"DOC(
Adadelta Updates Operator.
This implements the Adadelta optimizer[1]. Adadelta is a per-dimension
adaptive learning rate method for gradient descent.
Adadelta updates:
avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * grad * grad
param_update = - sqrt((avg_squared_update + epsilon) /
(avg_squared_grad_out + epsilon)) * grad
avg_squared_update_out = rho * avg_squared_update + (1 - rho) * param_update**2
param_out = param + param_update
References:
[1] ADADELTA: An Adaptive Learning Rate Method
https://arxiv.org/abs/1212.5701
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
REGISTER_OP_CPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adadelta_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
auto avg_squared_update_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredUpdateOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
float rho = ctx.Attr<float>("rho");
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
// Squared gradient accumulator
auto avg_squared_grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredGrad"));
// Squared updates accumulator
auto avg_squared_update = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredUpdate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto avg_squared_grad_out =
framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
auto avg_squared_update_out =
framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
avg_squared_grad_out.device(place) =
rho * avg_squared_grad + (1 - rho) * grad.square();
auto update =
-((avg_squared_update + epsilon) / (avg_squared_grad_out + epsilon))
.sqrt() *
grad;
avg_squared_update_out.device(place) =
rho * avg_squared_update + (1 - rho) * update.square();
param_out.device(place) = param + update;
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adagrad_op.h"
namespace paddle {
namespace operators {
class AdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(MomentOut) of AdagradOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"LearningRate should have one element");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdagradOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment"),
"Param and Moment input of AdagradOp should have the same dimension.");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
}
};
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
AddInput("LearningRate", "(Tensor) Learning rate");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("MomentOut", "(Tensor) Output second moment");
AddAttr<float>("epsilon",
"(float, default 1.0e-6) "
"Constant for numerical stability")
.SetDefault(1.0e-6f);
AddComment(R"DOC(
Adaptive Gradient Algorithm (Adagrad).
moment_out = moment + grad * grad
param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon)
The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
does not have the epsilon attribute. It is added here for numerical stability
by avoiding division by zero.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
REGISTER_OP_CPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
};
} // namespace operators
} // namespace paddle
...@@ -34,7 +34,7 @@ struct StridedMemcpyFunctor<T, 1> { ...@@ -34,7 +34,7 @@ struct StridedMemcpyFunctor<T, 1> {
auto& cpu_place = boost::get<platform::CPUPlace>(place); auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head); memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
} else { } else {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::GPUPlace>(place); auto& gpu_place = boost::get<platform::GPUPlace>(place);
auto& cuda_ctx = auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx); reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
......
...@@ -71,7 +71,7 @@ void testIm2col() { ...@@ -71,7 +71,7 @@ void testIm2col() {
context = context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else { } else {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
context = context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else #else
...@@ -116,7 +116,7 @@ void testIm2col() { ...@@ -116,7 +116,7 @@ void testIm2col() {
TEST(math, im2col) { TEST(math, im2col) {
testIm2col<paddle::platform::CPUPlace>(); testIm2col<paddle::platform::CPUPlace>();
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testIm2col<paddle::platform::GPUPlace>(); testIm2col<paddle::platform::GPUPlace>();
#endif #endif
} }
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(math_function, notrans_mul_trans) { TEST(math_function, notrans_mul_trans) {
paddle::framework::Tensor input1; paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu; paddle::framework::Tensor input1_gpu;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rmsprop_op.h"
namespace paddle {
namespace operators {
class RmspropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
"Input(MeanSquare) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(Momentum_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
"Output(MeanSquareOut) of RmspropOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and grad input of RmspropOp should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
ctx->SetOutputDim("MeanSquareOut", param_dim);
}
};
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated");
AddInput("MeanSquare",
"(Tensor, default Tensor<float>)"
" The mean square value that gets updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Moment",
"(Tensor, default Tensor<float>) The moment that gets updated");
AddOutput("ParamOut", "(Tensor) Output updated parameter value");
AddOutput("MomentOut", "(Tensor) Output updated moment");
AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9f);
AddAttr<float>("momentum", "(float, default 0.0) Constant value")
.SetDefault(0.0f);
AddComment(R"DOC(
RMSprop
MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
MomentOut = momentum * Moment +
LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
ParamOut = Param - MomentOut
The original slides that proposed RMSprop: Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
REGISTER_OP_CPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/rmsprop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
p_out.device(place) = p - mom_out;
}
};
} // namespace operators
} // namespace paddle
...@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of SGDOp should not be null."); "Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of SGDOp should not be null."); "Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"), PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(learning_rate) of SGDOp should not be null."); "Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of SGDOp should not be null."); "Output(ParamOut) of SGDOp should not be null.");
auto param_dim = ctx->GetInputDim("param"); auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"), PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx->SetOutputDim("param_out", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
}; };
...@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("Param", "Input parameter");
AddInput("learning_rate", "learning rate of sgd"); AddInput("LearningRate", "Learning rate of SGD");
AddInput("grad", "input gradient"); AddInput("Grad", "Input gradient");
AddOutput("param_out", "output parameter"); AddOutput("ParamOut", "output parameter");
AddComment(R"DOC( AddComment(R"DOC(
Simplest sgd algorithm. Simplest sgd algorithm.
......
...@@ -19,28 +19,25 @@ limitations under the License. */ ...@@ -19,28 +19,25 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<framework::Tensor>("Param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<framework::Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<framework::Tensor>("ParamOut");
float lr = *ctx.Input<float>("learning_rate"); auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
auto p = EigenVector<T>::Flatten(*param); auto p = framework::EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad); auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out); auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g; Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
} }
}; };
......
...@@ -72,7 +72,7 @@ TEST(StridedMemcpy, CPUConcat) { ...@@ -72,7 +72,7 @@ TEST(StridedMemcpy, CPUConcat) {
} }
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(StridedMemcpy, GPUCrop) { TEST(StridedMemcpy, GPUCrop) {
// clang-format off // clang-format off
int src[] = { int src[] = {
...@@ -157,4 +157,4 @@ TEST(StridedMemcpy, GPUConcat) { ...@@ -157,4 +157,4 @@ TEST(StridedMemcpy, GPUConcat) {
#endif #endif
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
\ No newline at end of file
...@@ -35,7 +35,7 @@ Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { ...@@ -35,7 +35,7 @@ Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <> template <>
Eigen::GpuDevice* Eigen::GpuDevice*
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
...@@ -61,7 +61,7 @@ class CPUDeviceContext : public DeviceContext { ...@@ -61,7 +61,7 @@ class CPUDeviceContext : public DeviceContext {
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <> template <>
struct EigenDeviceConverter<platform::GPUPlace> { struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice; using EigenDeviceType = Eigen::GpuDevice;
......
...@@ -20,7 +20,7 @@ TEST(Device, Init) { ...@@ -20,7 +20,7 @@ TEST(Device, Init) {
using paddle::platform::CUDADeviceContext; using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace; using paddle::platform::GPUPlace;
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device = Eigen::GpuDevice* gpu_device =
...@@ -34,7 +34,7 @@ TEST(Device, CUDADeviceContext) { ...@@ -34,7 +34,7 @@ TEST(Device, CUDADeviceContext) {
using paddle::platform::CUDADeviceContext; using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace; using paddle::platform::GPUPlace;
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
CUDADeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); CUDADeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device = device_context->eigen_device(); Eigen::GpuDevice* gpu_device = device_context->eigen_device();
......
...@@ -29,7 +29,7 @@ limitations under the License. */ ...@@ -29,7 +29,7 @@ limitations under the License. */
#include <cxxabi.h> // for __cxa_demangle #include <cxxabi.h> // for __cxa_demangle
#endif #endif
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
...@@ -113,7 +113,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -113,7 +113,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} }
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
......
...@@ -213,4 +213,4 @@ TEST(ENFORCE_USER_DEFINED_CLASS, EQ) { ...@@ -213,4 +213,4 @@ TEST(ENFORCE_USER_DEFINED_CLASS, EQ) {
TEST(ENFORCE_USER_DEFINED_CLASS, NE) { TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}}; Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
ASSERT_THROW(PADDLE_ENFORCE_EQ(a, b), paddle::platform::EnforceNotMet); ASSERT_THROW(PADDLE_ENFORCE_EQ(a, b), paddle::platform::EnforceNotMet);
} }
\ No newline at end of file
...@@ -26,11 +26,11 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, ...@@ -26,11 +26,11 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
namespace paddle { namespace paddle {
namespace platform { namespace platform {
int GetDeviceCount() { int GetCUDADeviceCount() {
int count; int count;
PADDLE_ENFORCE( PADDLE_ENFORCE(
cudaGetDeviceCount(&count), cudaGetDeviceCount(&count),
"cudaGetDeviceCount failed in paddle::platform::GetDeviceCount"); "cudaGetDeviceCount failed in paddle::platform::GetCUDADeviceCount");
return count; return count;
} }
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stddef.h> #include <stddef.h>
...@@ -28,7 +28,7 @@ const std::string kEnvFractionGpuMemoryToUse = ...@@ -28,7 +28,7 @@ const std::string kEnvFractionGpuMemoryToUse =
"PADDLE_FRACTION_GPU_MEMORY_TO_USE"; "PADDLE_FRACTION_GPU_MEMORY_TO_USE";
//! Get the total number of GPU devices in system. //! Get the total number of GPU devices in system.
int GetDeviceCount(); int GetCUDADeviceCount();
//! Get the current GPU device id in system. //! Get the current GPU device id in system.
int GetCurrentDeviceId(); int GetCurrentDeviceId();
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <boost/config.hpp> #include <boost/config.hpp>
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
// Because boost's variadic templates has bug on nvcc, boost will disable // Because boost's variadic templates has bug on nvcc, boost will disable
// variadic template support when GPU enabled on nvcc. // variadic template support when GPU enabled on nvcc.
......
...@@ -215,7 +215,7 @@ int main(int argc, char** argv) { ...@@ -215,7 +215,7 @@ int main(int argc, char** argv) {
uint64_t dataSize = FLAGS_dim * sizeof(real); uint64_t dataSize = FLAGS_dim * sizeof(real);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
GpuVector gpuParam(FLAGS_dim); GpuVector gpuParam(FLAGS_dim);
GpuVector gpuGrad(FLAGS_dim); GpuVector gpuGrad(FLAGS_dim);
#else #else
......
...@@ -99,7 +99,7 @@ TEST(ProtoServer, regular) { ...@@ -99,7 +99,7 @@ TEST(ProtoServer, regular) {
} }
TEST(ProtoServer, extended) { TEST(ProtoServer, extended) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
ProtoClient* client; ProtoClient* client;
if (FLAGS_rdma_tcp == "rdma") if (FLAGS_rdma_tcp == "rdma")
client = new ProtoClient(FLAGS_server_addr, FLAGS_port, F_RDMA); client = new ProtoClient(FLAGS_server_addr, FLAGS_port, F_RDMA);
......
...@@ -34,7 +34,7 @@ static size_t UniqueIntegerGenerator() { ...@@ -34,7 +34,7 @@ static size_t UniqueIntegerGenerator() {
} }
bool IsCompileGPU() { bool IsCompileGPU() {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
return false; return false;
#else #else
return true; return true;
...@@ -78,7 +78,7 @@ PYBIND11_PLUGIN(core) { ...@@ -78,7 +78,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCPUTensorSetFromArray<float>) .def("set", PyCPUTensorSetFromArray<float>)
.def("set", PyCPUTensorSetFromArray<int>) .def("set", PyCPUTensorSetFromArray<int>)
.def("set", PyCPUTensorSetFromArray<double>) .def("set", PyCPUTensorSetFromArray<double>)
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>) .def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>) .def("set", PyCUDATensorSetFromArray<int>)
.def("set", PyCUDATensorSetFromArray<double>) .def("set", PyCUDATensorSetFromArray<double>)
...@@ -96,7 +96,7 @@ PYBIND11_PLUGIN(core) { ...@@ -96,7 +96,7 @@ PYBIND11_PLUGIN(core) {
.def( .def(
"__init__", "__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) { [](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
new (&instance) LoDTensor(lod); new (&instance) LoDTensor(lod);
#else #else
LoD new_lod; LoD new_lod;
...@@ -107,7 +107,7 @@ PYBIND11_PLUGIN(core) { ...@@ -107,7 +107,7 @@ PYBIND11_PLUGIN(core) {
}) })
.def("set_lod", .def("set_lod",
[](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) { [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
self.set_lod(lod); self.set_lod(lod);
#else #else
LoD new_lod; LoD new_lod;
...@@ -117,7 +117,7 @@ PYBIND11_PLUGIN(core) { ...@@ -117,7 +117,7 @@ PYBIND11_PLUGIN(core) {
#endif #endif
}) })
.def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> { .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
return self.lod(); return self.lod();
#else #else
auto lod = self.lod(); auto lod = self.lod();
...@@ -203,7 +203,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -203,7 +203,7 @@ All parameter, weight, gradient are variables in Paddle.
.def_static("create", .def_static("create",
[](paddle::platform::GPUPlace& place) [](paddle::platform::GPUPlace& place)
-> paddle::platform::DeviceContext* { -> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("GPUPlace is not supported in CPU device."); PADDLE_THROW("GPUPlace is not supported in CPU device.");
#else #else
return new paddle::platform::CUDADeviceContext(place); return new paddle::platform::CUDADeviceContext(place);
......
...@@ -106,7 +106,7 @@ void PyCPUTensorSetFromArray( ...@@ -106,7 +106,7 @@ void PyCPUTensorSetFromArray(
std::memcpy(dst, array.data(), sizeof(T) * array.size()); std::memcpy(dst, array.data(), sizeof(T) * array.size());
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
template <typename T> template <typename T>
void PyCUDATensorSetFromArray( void PyCUDATensorSetFromArray(
framework::Tensor &self, framework::Tensor &self,
......
...@@ -36,4 +36,4 @@ TEST(to_string, user_defined) { ...@@ -36,4 +36,4 @@ TEST(to_string, user_defined) {
using namespace paddle::string; using namespace paddle::string;
UserDefinedClass instance; UserDefinedClass instance;
ASSERT_EQ(kOutputString, to_string(instance)); ASSERT_EQ(kOutputString, to_string(instance));
} }
\ No newline at end of file
...@@ -29,7 +29,7 @@ int main(int argc, char** argv) { ...@@ -29,7 +29,7 @@ int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
initPython(argc, argv); initPython(argc, argv);
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir); string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
FLAGS_use_gpu = false; FLAGS_use_gpu = false;
#endif #endif
auto config = std::make_shared<TrainerConfigHelper>(confFile); auto config = std::make_shared<TrainerConfigHelper>(confFile);
......
...@@ -146,7 +146,7 @@ void compareGradient(comData& comDataCpu, comData& comDataGpu) { ...@@ -146,7 +146,7 @@ void compareGradient(comData& comDataCpu, comData& comDataGpu) {
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
exit(0); exit(0);
#endif #endif
paddle::initMain(argc, argv); paddle::initMain(argc, argv);
......
...@@ -174,7 +174,7 @@ TEST(compareSparse, multiGradientMachine) { ...@@ -174,7 +174,7 @@ TEST(compareSparse, multiGradientMachine) {
FLAGS_local = local; FLAGS_local = local;
FLAGS_ports_num_for_sparse = 5; FLAGS_ports_num_for_sparse = 5;
for (bool useGpu : {false, true}) { for (bool useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) continue; if (useGpu) continue;
#endif #endif
FLAGS_parallel_nn = useGpu; FLAGS_parallel_nn = useGpu;
...@@ -198,7 +198,7 @@ TEST(compareSparse, NeuralNetwork) { ...@@ -198,7 +198,7 @@ TEST(compareSparse, NeuralNetwork) {
FLAGS_local = local; FLAGS_local = local;
FLAGS_ports_num_for_sparse = 5; FLAGS_ports_num_for_sparse = 5;
for (bool useGpu : {false, true}) { for (bool useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
if (useGpu) continue; if (useGpu) continue;
#endif #endif
FLAGS_parallel_nn = useGpu; FLAGS_parallel_nn = useGpu;
......
...@@ -51,7 +51,7 @@ void checkGradientTest(const string& configFile, ...@@ -51,7 +51,7 @@ void checkGradientTest(const string& configFile,
TEST(checkGradient, cpu) { checkGradientTest(configFile1, false, false); } TEST(checkGradient, cpu) { checkGradientTest(configFile1, false, false); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(checkGradient, gpu) { checkGradientTest(configFile1, true, false); } TEST(checkGradient, gpu) { checkGradientTest(configFile1, true, false); }
TEST(checkGradient, multiGpu) { TEST(checkGradient, multiGpu) {
...@@ -97,7 +97,7 @@ TEST(checkGradient, hsigmoid) { checkGradientTest(configFile2, false, false); } ...@@ -97,7 +97,7 @@ TEST(checkGradient, hsigmoid) { checkGradientTest(configFile2, false, false); }
TEST(checkGradient, chunk) { TEST(checkGradient, chunk) {
checkGradientTest(configFile3, false, false); checkGradientTest(configFile3, false, false);
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
checkGradientTest(configFile3, true, true); checkGradientTest(configFile3, true, true);
#endif #endif
} }
......
...@@ -79,7 +79,7 @@ void trainerOnePassTest(const string& configFile, ...@@ -79,7 +79,7 @@ void trainerOnePassTest(const string& configFile,
// 1. test trainer (cpu, gpu). // 1. test trainer (cpu, gpu).
TEST(trainerOnePass, cpu) { trainerOnePassTest(configFile1, false, false); } TEST(trainerOnePass, cpu) { trainerOnePassTest(configFile1, false, false); }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(trainerOnePass, gpu) { trainerOnePassTest(configFile1, true, false); } TEST(trainerOnePass, gpu) { trainerOnePassTest(configFile1, true, false); }
TEST(trainerOnePass, gpu2) { trainerOnePassTest(configFile1, true, false, 2); } TEST(trainerOnePass, gpu2) { trainerOnePassTest(configFile1, true, false, 2); }
...@@ -94,7 +94,7 @@ TEST(trainerOnePass, parallel) { ...@@ -94,7 +94,7 @@ TEST(trainerOnePass, parallel) {
#endif #endif
// 2. test average_window. // 2. test average_window.
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(average_window, gpu) { TEST(average_window, gpu) {
trainerOnePassTest(configFile1, true, false, 4, 0.01); trainerOnePassTest(configFile1, true, false, 4, 0.01);
} }
...@@ -266,7 +266,7 @@ TEST(checkRemoteUpdater, cpuTrainerOldUpdater) { ...@@ -266,7 +266,7 @@ TEST(checkRemoteUpdater, cpuTrainerOldUpdater) {
checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true); checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true);
} }
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
TEST(checkRemoteUpdater, gpuTrainer) { TEST(checkRemoteUpdater, gpuTrainer) {
checkRemoteParameterUpdaterTest(configFile1, true, false); checkRemoteParameterUpdaterTest(configFile1, true, false);
} }
......
...@@ -113,7 +113,7 @@ void testGeneration(const string& configFile, ...@@ -113,7 +113,7 @@ void testGeneration(const string& configFile,
#ifndef PADDLE_TYPE_DOUBLE #ifndef PADDLE_TYPE_DOUBLE
TEST(RecurrentGradientMachine, test_generation) { TEST(RecurrentGradientMachine, test_generation) {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
const auto useGpuConfs = {false}; const auto useGpuConfs = {false};
#else #else
const auto useGpuConfs = {true, false}; const auto useGpuConfs = {true, false};
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "Flags.h" #include "Flags.h"
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
DEFINE_bool(use_gpu, false, "Only support CPU training"); DEFINE_bool(use_gpu, false, "Only support CPU training");
#else #else
DEFINE_bool(use_gpu, true, "Whether to use GPU for training"); DEFINE_bool(use_gpu, true, "Whether to use GPU for training");
......
...@@ -218,7 +218,7 @@ protected: ...@@ -218,7 +218,7 @@ protected:
* *d2* is peer device to enable direct access to by the d1 device. * *d2* is peer device to enable direct access to by the d1 device.
*/ */
inline void enablePeerAccess(int d1, int d2) { inline void enablePeerAccess(int d1, int d2) {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
if (hl_device_can_access_peer(d1, d2)) { if (hl_device_can_access_peer(d1, d2)) {
SetDevice dev(d1); SetDevice dev(d1);
hl_device_enable_peer_access(d2); hl_device_enable_peer_access(d2);
......
...@@ -48,7 +48,7 @@ void printVersion(std::ostream& os); ...@@ -48,7 +48,7 @@ void printVersion(std::ostream& os);
* @return return true if paddle compiled with GPU * @return return true if paddle compiled with GPU
*/ */
constexpr bool isWithGpu() { constexpr bool isWithGpu() {
#ifdef PADDLE_ONLY_CPU #ifndef PADDLE_WITH_CUDA
return false; return false;
#else #else
return true; return true;
......
...@@ -122,6 +122,23 @@ class TestBRelu(OpTest): ...@@ -122,6 +122,23 @@ class TestBRelu(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.02) self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestLeakyRelu(OpTest):
def setUp(self):
self.op_type = "leaky_relu"
alpha = 0.02
self.attrs = {'alpha': alpha}
self.inputs = {'X': np.random.uniform(-3, 3, [4, 4]).astype("float32")}
self.outputs = {
'Y': np.maximum(self.inputs['X'], alpha * self.inputs['X'])
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSoftRelu(OpTest): class TestSoftRelu(OpTest):
def setUp(self): def setUp(self):
self.op_type = "soft_relu" self.op_type = "soft_relu"
......
import unittest
import numpy as np
from op_test import OpTest
class TestAdadeltaOp1(OpTest):
def setUp(self):
self.op_type = "adadelta"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The squared gradient is positive
avg_squared_grad = np.random.random((102, 105)).astype("float32")
# The squared update is positive
avg_squared_update = np.random.random((102, 105)).astype("float32")
rho = 0.95
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'AvgSquaredGrad': avg_squared_grad,
'AvgSquaredUpdate': avg_squared_update
}
self.attrs = {'rho': rho, 'epsilon': epsilon}
avg_squared_grad_out = rho * avg_squared_grad + \
(1 - rho) * np.square(grad)
update = -np.multiply(
np.sqrt(
np.divide(avg_squared_update + epsilon, avg_squared_grad_out +
epsilon)), grad)
avg_squared_update_out = rho * avg_squared_update + \
(1 - rho) * np.square(update)
param_out = param + update
self.outputs = {
'ParamOut': param_out,
'AvgSquaredGradOut': avg_squared_grad_out,
'AvgSquaredUpdateOut': avg_squared_update_out
}
def test_check_output(self):
self.check_output()
class TestAdadeltaOp2(OpTest):
'''Test Adadelta op with default attribute values
'''
def setUp(self):
self.op_type = "adadelta"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The squared gradient is positive
avg_squared_grad = np.random.random((102, 105)).astype("float32")
# The squared update is positive
avg_squared_update = np.random.random((102, 105)).astype("float32")
rho = 0.95
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'AvgSquaredGrad': avg_squared_grad,
'AvgSquaredUpdate': avg_squared_update
}
avg_squared_grad_out = rho * avg_squared_grad + \
(1 - rho) * np.square(grad)
update = -np.multiply(
np.sqrt(
np.divide(avg_squared_update + epsilon, avg_squared_grad_out +
epsilon)), grad)
avg_squared_update_out = rho * avg_squared_update + \
(1 - rho) * np.square(update)
param_out = param + update
self.outputs = {
'ParamOut': param_out,
'AvgSquaredGradOut': avg_squared_grad_out,
'AvgSquaredUpdateOut': avg_squared_update_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestAdagradOp1(OpTest):
''' Test Adagrad operator with explicit attributes
'''
def setUp(self):
self.op_type = "adagrad"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
lr = 0.01
epsilon = 1e-8
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'LearningRate': np.array([lr]).astype("float32")
}
self.attrs = {'epsilon': epsilon}
moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
class TestAdagradOp2(OpTest):
''' Test Adagrad operator with default attributes
'''
def setUp(self):
self.op_type = "adagrad"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
lr = 0.01
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'LearningRate': np.array([lr]).astype("float32")
}
self.attrs = {'epsilon': epsilon}
moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestRmspropOp1(OpTest):
''' Test RMSProp with explicit inputs
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1e-6
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
class TestRmspropOp2(OpTest):
'''Test RMSProp with defaukt values for attributes
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1.0e-10
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
...@@ -8,10 +8,10 @@ class TestSGDOp(OpTest): ...@@ -8,10 +8,10 @@ class TestSGDOp(OpTest):
self.op_type = "sgd" self.op_type = "sgd"
w = np.random.random((102, 105)).astype("float32") w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32") g = np.random.random((102, 105)).astype("float32")
lr = 0.1 lr = np.array([0.1]).astype("float32")
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr} self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.outputs = {'param_out': w - lr * g} self.outputs = {'ParamOut': w - lr * g}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册