提交 03363041 编写于 作者: K Kavya Srinet

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rmsprop

......@@ -49,11 +49,12 @@ if(NOT WITH_GOLANG)
endif(NOT WITH_GOLANG)
if(NOT WITH_GPU)
add_definitions(-DPADDLE_ONLY_CPU)
add_definitions(-DHPPL_STUB_FUNC)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
else()
add_definitions(-DPADDLE_WITH_CUDA)
FIND_PACKAGE(CUDA REQUIRED)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
......
# Design Doc: Session
## Abstract
The *session* object encapsulates the environment in which the
computation graph is executed.
We will have the *local* session and *remote* session, they offer the
same [interface](#interface). The local session encapsulates the local
runtime environment and the remote session encapsulates the cluster
runtime environment.
The local runtime environment contains:
1. computation devices (i.e., CPU, GPU) handles, and
1. the [scope](../scope.md) which holds all variables.
The remote runtime environment contains:
1. computation devices (i.e., CPU and GPU on node 0, 1) in a cluster,
and
1. the distributed [scope](../scope.md) in a cluster which holds all
variables.
The user can create a remote session on Paddle Cloud and evaluate the
computation graph with it. In this way, the user can control the
remote computation resource in a cluster from his local computer.
## Background
The current design has an implicit global session in which
`paddle.eval()` is executed. The pain point is:
Since the user is not able to explicitly switch between runtime
environments, the user cannot run a topology in two independent
environments.
For example, in reinforcement learning, the user may want to have a
stale model for inference and a fresh model for training, and only
replace the stale model with the fresh model periodically.
Furthermore, we have no concept that encapsulates a remote environment
that executes a computation graph.
We need the session object to address above issues.
## Session
A session is an object that owns the runtime environment. All
computations are executed through `session.eval()`.
### Interface
```python
eval(
targets,
feed_dict=None,
)
```
Evaluates the target Operations or Variables in `targets`.
- *targets*: the evaluation targets. Can be a single Operation or
Variable, or a list with the Operations or Variables as
elements. The value returned by `eval()` has the same shape as the
`target` argument.
The PaddlePaddle program is represented by
the [ProgramDesc](../design/program.md), `eval()` will infer the
ProgramDesc from the given targets and run the PaddlePaddle
program. Please
see
[this graph](./distributed_architecture.md#local-training-architecture) for
the detailed illustration for the local session
and
[this graph](./distributed_architecture.md#distributed-training-architecture) for
the detailed illustration for the remote session.
- *feed_dict*: a dictionary that contains the tensors which override
the edges of the computation graph.
feed_dict not only can provide the input data, it can override any
OP's input as well:
```python
a = pd.constant(2.0, name="a")
b = pd.variable(name="b")
c = pd.mul(a,b)
sess.eval(targets=c, feed_dict={"b":3.0}) # returns 6.0
```
```python
close()
```
Closes the session and releases the scope that the session owns.
### Create a Local Session
```python
session(
devices=None
)
```
Creates a new session. One session owns one global scope, so creating
multiple sessions will create different scopes.
- *devices*: a single `string` or a list of `string` of device names,
the corresponding devices will be the computation devices for
`eval()`. If not specified, all available devices (e.g., all GPUs)
will be used. The user doesn't need to specify the CPU device since
it will be always used. Multiple sessions can use the same device.
#### Example
```Python
a = paddle.constant(1.0)
b = paddle.constant(2.0)
c = a + b
sess = paddle.session(devices=["gpu:0", "gpu:1", "fpga:0"])
sess.eval(c)
sess.close()
```
### Create a Remote Session
```python
create_cloud_job(
name,
num_trainer,
mem_per_trainer,
gpu_per_trainer,
cpu_per_trainer,
num_ps,
mem_per_ps,
cpu_per_ps,
)
```
Creates a Paddle Cloud job. Fails if the job name exists.
```python
get_cloud_job(
name
)
```
Gets a Paddle Cloud job.
```python
remote_session(
job
)
```
- *job*: the Paddle Cloud job.
#### Example
```Python
reader = paddle.reader.recordio("/pfs/home/peter/mnist-train-*") # data stored on Paddle Cloud
image = reader.column(0)
label = reader.column(1)
fc1 = paddle.op.fc(image, size=256, act="sigmoid")
fc2 = paddle.op.fc(fc1, size=10, act="softmax")
cost = paddle.op.cross_entropy(fc2, label)
opt = paddle.optimizer.sgd(cost)
job = paddle.create_cloud_job("test", 3, "1G", 1, 1, 2, "1G", 1)
sess = paddle.remote_ession(job)
for i in range(1000):
sess.eval(opt)
sess.close()
```
......@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun
```cpp
// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
std::function<std::vector<OpDescBind>(const OpDescBind&)>;
```
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions.
The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for protobuf message `OpDesc` to manipulate `OpDesc` fast.
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
```cpp
struct OpInfo {
GradOpDescMaker grad_op_maker_;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
...
};
```
The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators.
We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is
```cpp
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;
std::function<std::vector<OpDescBind>(const OpDescBind&)> func;
func = [] (const OpDescBind& fwd_op) {
GradOpMaker maker(fwd_op);
return maker();
};
```
We can write many helper functions since the `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forwarding operator.
We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`.
The user interface should be
......
......@@ -47,7 +47,7 @@ bool isUsingGpu() { return FLAGS_use_gpu; }
void setUseGpu(bool useGpu) { FLAGS_use_gpu = useGpu; }
bool isGpuVersion() {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
return false;
#else
return true;
......
......@@ -46,7 +46,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE;
paddle::real* buf = ptr->mat->getRowBuf(rowID);
size_t width = ptr->mat->getWidth();
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
hl_memcpy(buf, rowArray, sizeof(paddle::real) * width);
#else
std::copy(rowArray, rowArray + width, buf);
......
......@@ -141,9 +141,26 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
// collect all the offset to append `add` op for each alias
insert_position.push_back(
{dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
{{"Out", {name}}}, {})});
//
// one variable is shared between multiple operators.
// insert add operator one by one, then add it to output
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
++output_idx) {
auto insert_add_x = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx + 1];
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
// first add op inserted
if (output_idx == dup_outputs.size() - 2) {
insert_add_out = name;
}
if (output_idx != 0) {
insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
}
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
}
}
// make sure the inserted `add` ops follow the BFS order.
......@@ -182,7 +199,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or
// this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
......
......@@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").AsDuplicable();
AddOutput("Out", "out");
AddInput("X", "the input tensors of sum operator.")
.AsDuplicable()
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment("");
}
};
} // namespace framework
} // namespace paddle
......@@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP);
......@@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->Type());
ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
}
TEST(Backward, op_register_grad_not_for_network) {
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/var_desc.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace framework {
......@@ -34,9 +35,6 @@ class BlockDescBind {
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &o) = delete;
BlockDescBind &operator=(const BlockDescBind &o) = delete;
int32_t ID() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); }
......@@ -66,6 +64,8 @@ class BlockDescBind {
std::deque<std::unique_ptr<OpDescBind>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
DISABLE_COPY_AND_ASSIGN(BlockDescBind);
};
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
......@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T();
info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
T maker(fwd_op);
return maker();
};
}
};
} // namespace details
......
......@@ -39,28 +39,6 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
"sum", {{"X", {"x", "y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(grad_add_op->Inputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Outputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
auto &outputs = grad_add_op->Outputs(f::GradVarName("X"));
EXPECT_EQ(2UL, outputs.size());
auto in_output = [&outputs](const std::string &name) {
for (auto &output_name : outputs) {
if (output_name == name) return true;
}
return false;
};
EXPECT_TRUE(in_output(f::GradVarName("x")));
EXPECT_TRUE(in_output(f::GradVarName("y")));
}
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
......@@ -205,4 +183,4 @@ TEST(GradOpDescBuilder, IOIgnoredInGradient) {
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
delete forw_op;
delete grad_op;
}
\ No newline at end of file
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected:
static std::vector<std::string> ToGradNames(
const std::vector<std::string>& var_names) {
std::vector<std::string> ret_val;
ret_val.reserve(var_names.size());
std::transform(var_names.begin(), var_names.end(),
std::back_inserter(ret_val), GradVarName);
return ret_val;
}
std::vector<std::string> InputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Input(name));
}
std::vector<std::string> OutputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Output(name));
}
std::vector<std::string> InputNames() const {
return this->fwd_op_.InputNames();
}
std::vector<std::string> OutputNames() const {
return this->fwd_op_.OutputNames();
}
std::vector<std::string> Input(const std::string& name) const {
return fwd_op_.Input(name);
}
std::vector<std::string> Output(const std::string& name) const {
return fwd_op_.Output(name);
}
const std::unordered_map<std::string, Attribute>& Attrs() const {
return fwd_op_.GetAttrMap();
}
const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
return it->second;
}
std::string ForwardOpType() const { return this->fwd_op_.Type(); }
private:
const OpDescBind& fwd_op_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
} // namespace framework
} // namespace paddle
......@@ -15,7 +15,7 @@
#pragma once
#include <memory>
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
......@@ -29,7 +29,7 @@
namespace paddle {
namespace framework {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
template <typename T>
using Vector = std::vector<T>;
#else
......
# Design Doc: LoD (Level-of-Detail) Tensor
PaddlePaddle's RNN doesn't require that all instances have the same length. To do so, we introduce an extension to Tensor, namely, LoD Tensor.
Like other deep learning systems, PaddlePaddle supports training models from sequence data. Also, like other systems, PaddlePaddle represent a mini-batch of sequences as a Tensor. What is different is that PaddlePaddle doesn't require all sequences in a mini-batch to be of the same length. Thus no need for padding zeros.
## Challenge of Variable-length Inputs
| | TensorFlow | PaddlePaddle |
|-----------------------|------------|--------------|
| RNN | Support | Support |
| recursive RNN | Support | Support |
| padding zeros | Must | No need |
| blob data type | Tensor | LoDTensor |
People usually represent a mini-batch by a Tensor. For example, a mini-batch of 10 images, each of size 32x32, is a 10x32x32 Tensor. So a transformation, T, of all images can be a matrix multiplication of the 10xOx32-dimensional tensor T and the 10x32x32 Tensor.
PaddlePaddle achieves this flexibility by passing through a new data type, *LoD Tensor*, which is a Tensor attached with segmentation index known as *LoD*, between operators. The LoD index doesn't only segment a tensor, but also recursively segments sub-sequences. This document presents the design of LoD and LoDTensor.
Another example is that each mini-batch contains 32 sentences, where each word is a D-dimensional one-hot vector. If all sentences have the same length L, we can represent this mini-batch by a 32xLxD tensor. However, in most cases, sentences have variable lengths, and we will need an index data structure to record these variable lengths.
## LoD as a Solution
## The Challenge: Variable-length Sequences
### Mini-Batch of variable-length sentences
Most deep learning systems represent a mini-batch as a Tensor. For example, a mini-batch of 10 images, each of size 32x32, is a 10x32x32 Tensor. Another example is that each mini-batch contains N sentences, where each word is a D-dimensional one-hot vector. Suppose that all sentences have the same length L, we can represent this mini-batch by a NxLxD tensor.
Let's imagine a mini-batch of 3 variable lengths sentences, containing 3, 1, and 2 words respectively. We can represent it by a (3+1+2)xD tensor plus some index information:
Both examples show that the elements of sequences are usually of the same size. In the first example, all images are 32x32, and in the second one, all words are D-dimensional vectors. It doesn't make sense to allow variable-sized images, as that would require transformations like convolution to handle variable-sized Tensors.
The real challenge is that in most cases, sentences have variable lengths, and we will need an index data structure to segment the tensor into sequences. Also, sequences might consist of sub-sequences.
## A Solution: The LoD Index
To understand our solution, it is best to look at some examples.
### A Mini-Batch of Sentences
Let's imagine a mini-batch of 3 variable lengths sentences composed of 3, 1, and 2 words, respectively. We can represent the mini-batch by a (3+1+2)xD tensor plus some index information:
```
3
3 1 2
||| | ||
```
Each `|` represents a D-dimensional word vectors. The number 3 on top indicate 3 sentences, and numbers 3, 1, and 2 on the second level represent the number of words in each sentence.
where each `|` represents a D-dimensional word vector. The numbers, 3, 1, and 2, form a 1-level LoD.
### Recursive Sequences
Let check another example of a 2-level LoD Tensor. Consider a mini-batch of three articles with 3, 1, and 2 sentences, and each sentence consists of a variable number of words:
```
3 1 2
3 2 4 1 2 3
||| || |||| | || |||
```
### Mini-Batch of variable-length videos
### A Mini-Batch of Videos
This approach generalizes to the case where elements are not words, but higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. If a mini-batch contains 3 videos of 3, 1, and 2 frames respectively. The underlying tensor is of size (3+1+2)x640x480. The index information illustrates as:
LoD tensors generalize to the case where elements are higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. Here is a mini-batch of 3 videos with 3, 1, and 2 frames, respectively.
```
3
3 1 2
口口口 口 口口
```
where each `口` represents an image.
The underlying tensor is of size (3+1+2)x640x480, and each `口` represents a 640x480 image.
### Mini-Batch of fixed-size images
### A Mini-Batch of Images
Let's get back to a typical example, image classification, where each mini-batch has M fixed-sized images. The LoD Tensor representation is
In traditional cases like a mini-batch with N fixed-sized images, the LoD Tensor representation is as
```
M
1 1 1 1 1
口口口口 ... 口
```
The many 1's on the second level seem duplicated. For this particular case of 2 levels and the second level always have length 1, we can ignore the LoD index.
### Design and summarization
In this case, we don't lose any information by ignoring the many 1's in the index and simply considering this LoD Tensor as a usual Tensor:
In summary, as long as that the essential elements (words or images) have the same size, we can represent mini-batches by a LoD Tensor:
```
口口口口 ... 口
```
- The underlying tensor has size LxD1xD2x..., where D1xD2... is the size of the essential elements, and
- The first dimension size L has an additonal property -- a LoD index as a nested vector:
### Model Parameters
```c++
typedef std::vector<std::<vector>> LoD;
```
A model parameter is just a usual Tensor, which, just like the above example, is a **0-level LoD Tensor**.
- The LoD index is not necessary when there are only two levels and all elements of the second level have length 1.
## Slicing of LoD Tensor
## The LoD Tensor
Consider that we have a network with three levels of RNN: the top level one handles articles, the second level one handles sentences, and the basic level one handles words. This network requires that mini-batches represented by 3 level LoD Tensor, for example,
Let us revisit above example of the 2-level LoD Tensor
```
3
3 1 2
3 2 4 1 2 3
||| || |||| | || |||
```
To allow each level of RNN to handle its input, we define **the slicing of a LoD Tensor is defined as getting the j-th sequence on level i, or the <i,j>-slice**
It is indeed a tree, where leaves are elementary sequences identified by **branches**.
For example, the third sentence in above example is identified by branch <0,2>, where 0 indicates the first article with length 3, and 2 indicates the third sentence in this article with length 4.
### The LoD Index
For example, the <2,1>-slice of above slice is
We can save the LoD index in the above example
```
2
||
3 1 2
3 2 4 1 2 3
```
and the <1,2>-slice of above example is
in a not-full 2D matrix:
```c++
typedef std::vector<std::vector<int> > LoD;
```
2
2 3
|| |||
```
Let's go on slicing this slice. Its <1,1>-slice is
where
- `LoD.size()` is the number of levels, or the maximum length of branches,
- `LoD[i][j]` is the length of the j-th segment at the i-th level.
## The Offset Representation
To quickly access elementary sequences, we adopt an offset representation -- instead of saving the lengths, we save the beginning and ending elements of sequences.
In the above example, we accumulate the length of elementary sequences:
```
1
1
|
3 2 4 1 2 3
```
### The Slicing Algorithm
into offsets
The algorithm, with over-simplified data structure, is defined as
```
0 3 5 9 10 12 15
= = = = = =
3 2+3 4+5 1+9 2+10 3+12
```
```c++
typedef std::vector<std::vector<int>> LoD;
so we know that the first sentence is from word 0 to word 3, and the second sentence from work 3 to word 5.
struct LoDTensor {
LoD lod_;
float* tensor_;
};
Similarly, the lengths in the top level LoD
LoDTensor Slice(const LoDTensor& lodt, int level, int sequence);
```
3 1 2
```
Let us revisit the example above
are transformed into offsets of elements/words as follows:
```
3
3 1 2
3 2 4 1 2 3
||| || |||| | || |||
0 9 10 15
= = =
3+2+4 1+9 2+3+10
```
Suppose that we want to retrieve the <1,2>-slice
so we can tell that the first article is from word 0 to word 9, and the second article is from word 9 to word 10.
The complete offset representation is as follows:
```
2
2 3
|| |||
0 9 10 15
0 3 5 9 10 12 15
||| || |||| | || |||
```
we will need to find out the starting position of this slice by summing over all leaf nodes in `LoD` to the left of the slice, i.e., 3 + 2 + 4 + 1 = 10.
## Slicing of LoD Tensors
When we use the above 2-level LoD Tensor as the input to a nested-RNN, we need to retrieve certain sequences. Here we define the sequence identified by branch <i,j,...> as the **<i,j,...>-slice**.
To avoid the traversal of the LoD tree at slicing time, we can do it at the construction time -- instead of saving the lengths of the next level in the LoD tree, we can save the starting offset of the next level. For example, above LoD Tensor can be transformed into
For example, the <2>-slice of above example is
```
0
0 9 10
0 3 5 9 10 12
||| || |||| | || |||
10 15
10 12 15
|| |||
```
We don't really need the 0 on top, so the LoD Tensor could be
and the <2,0>-slice of above slice is
```
0 9 10
0 3 5 9 10 12
||| || |||| | || |||
10 12
||
```
......@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
return it->second;
}
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
return it->second;
}
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......
......@@ -35,15 +35,11 @@ class OpDescBind {
const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name,
const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
......@@ -61,9 +57,6 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
......@@ -71,7 +64,23 @@ class OpDescBind {
// Only be used in C++
const AttributeMap &GetAttrMap() const;
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
std::vector<typename MapType::key_type> ret_val;
ret_val.reserve(map.size());
std::transform(
map.begin(), map.end(), std::back_inserter(ret_val),
[](const typename MapType::value_type &pair) { return pair.first; });
return ret_val;
}
void Sync();
OpDesc op_desc_;
......
......@@ -20,20 +20,15 @@
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
GradOpDescMakerBase* grad_op_maker_{nullptr};
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
......@@ -67,11 +62,6 @@ class OpInfoMap {
public:
static OpInfoMap& Instance();
OpInfoMap(const OpInfoMap& o) = delete;
OpInfoMap(OpInfoMap&& o) = delete;
OpInfoMap& operator=(const OpInfoMap& o) = delete;
OpInfoMap& operator=(OpInfoMap&& o) = delete;
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
......@@ -107,6 +97,8 @@ class OpInfoMap {
private:
OpInfoMap() = default;
std::unordered_map<std::string, const OpInfo> map_;
DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};
} // namespace framework
......
......@@ -48,4 +48,4 @@ TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
}
......@@ -211,7 +211,7 @@ class OpKernelRegistrar : public Registrar {
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type) \
......
......@@ -183,4 +183,4 @@ class CosineOpComplete : public paddle::framework::CosineOp {
TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
}
\ No newline at end of file
}
......@@ -25,7 +25,7 @@ Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
return *device_context_.GetEigenDevice<platform::CPUPlace>();
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace framework {
......@@ -26,9 +27,6 @@ class ProgramDescBind {
public:
static ProgramDescBind &Instance(ProgramDesc *prog);
ProgramDescBind(const ProgramDescBind &o) = delete;
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
......@@ -46,6 +44,8 @@ class ProgramDescBind {
ProgramDesc *prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
DISABLE_COPY_AND_ASSIGN(ProgramDescBind);
};
} // namespace framework
} // namespace paddle
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include "paddle/framework/variable.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace framework {
......@@ -38,11 +39,6 @@ class Scope {
Scope() {}
~Scope();
// Disable Copy, Assign, Move.
Scope(const Scope& other) = delete;
Scope& operator=(const Scope& other) = delete;
Scope(Scope&& other) = delete;
/// Create a sub-scope. Returns a reference other than a pointer so
/// to prevent from manual deletion.
/// Mark it to const because that new kid scope cannot change parent scope.
......@@ -73,6 +69,8 @@ class Scope {
std::unordered_map<std::string, Variable*> vars_;
mutable std::list<Scope*> kids_;
Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope);
};
} // namespace framework
......
......@@ -47,13 +47,6 @@ class TensorArray {
// max number of values allowed to store.
const size_t MAX_SIZE{100000};
/*
* Inputs:
* - value_shared: share memory between tensors.
*/
explicit TensorArray(bool values_shared = true)
: values_shared_(values_shared) {}
/*
* Read the value at location `index` in the `TensorArray`.
*/
......@@ -111,7 +104,6 @@ class TensorArray {
private:
mutable std::vector<LoDTensor> values_;
bool values_shared_;
}; // class TensorArray
} // namespace framework
......
......@@ -65,7 +65,7 @@ inline T* Tensor::mutable_data(platform::Place place) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
}
#else
......@@ -103,7 +103,7 @@ inline void Tensor::CopyFrom(const Tensor& src,
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
......
......@@ -74,7 +74,7 @@ TEST(Tensor, MutableData) {
EXPECT_EQ(p1, p2);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
float* p1 = nullptr;
......@@ -126,7 +126,7 @@ TEST(Tensor, ShareDataWith) {
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
Tensor dst_tensor;
......@@ -163,7 +163,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
......@@ -218,7 +218,7 @@ TEST(Tensor, CopyFrom) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
Tensor gpu_tensor;
......
......@@ -20,6 +20,7 @@
namespace paddle {
namespace framework {
class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
......@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework
} // namespace paddle
......@@ -194,7 +194,7 @@ public:
REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward);
#endif
......
......@@ -395,7 +395,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC(ContextProjectionBackward,
CPU,
ContextProjectionBackwardFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(ContextProjectionForward,
GPU,
ContextProjectionForwardFunc);
......
......@@ -233,7 +233,7 @@ private:
REGISTER_TYPED_FUNC(CosSimForward, CPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, CPU, CosSimBackwardFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(CosSimForward, GPU, CosSimForwardFunc);
REGISTER_TYPED_FUNC(CosSimBackward, GPU, CosSimBackwardFunc);
#endif
......
......@@ -169,7 +169,7 @@ private:
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif
......
......@@ -336,7 +336,7 @@ private:
REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
#endif
......
......@@ -292,7 +292,7 @@ REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
REGISTER_TYPED_FUNC(DepthwiseConvGradFilter,
CPU,
DepthwiseConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
GPU,
......
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(DepthwiseConv, Forward) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "DepthwiseConv-GPU", forward);
......
......@@ -340,7 +340,7 @@ public:
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
......
......@@ -24,7 +24,7 @@ TEST(GemmConv, NaiveConv) {
"NaiveConv-CPU", "GemmConv-CPU", forward);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(GemmConv, Forward) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "GemmConv-GPU", forward);
......
......@@ -116,7 +116,7 @@ void TestIm2ColFunctor() {
TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); }
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
......
......@@ -341,7 +341,7 @@ private:
};
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
} // namespace paddle
......@@ -207,7 +207,7 @@ private:
REGISTER_TYPED_FUNC(Pad, CPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, CPU, PadGradFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(Pad, GPU, PadFunc);
REGISTER_TYPED_FUNC(PadGrad, GPU, PadGradFunc);
#endif
......
......@@ -217,7 +217,7 @@ public:
REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc);
#endif
......
......@@ -132,7 +132,7 @@ public:
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "BatchNormalizationLayer.h"
#include "Layer.h"
#include "paddle/utils/Stat.h"
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include "CudnnBatchNormLayer.h"
#endif
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/utils/Stat.h"
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include "hl_batch_transpose.h"
#endif
#include "BatchNormalizationLayer.h"
......@@ -90,7 +90,7 @@ void BatchNormalizationLayer::expandMat(const MatrixPtr& in, MatrixPtr& out) {
size_t batchSize = in->getHeight();
CHECK_EQ(out->getHeight(), batchSize * imgPixels_);
if (useGpu_) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
LOG(FATAL) << "paddle is compiled only for cpu";
#else
batchTranspose(
......@@ -127,7 +127,7 @@ void BatchNormalizationLayer::shrinkMat(const MatrixPtr& in, MatrixPtr& out) {
}
CHECK_EQ(in->getHeight(), static_cast<size_t>(batchSize * imgPixels_));
if (useGpu_) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
LOG(FATAL) << "paddle is compiled only for cpu";
#else
batchTranspose(
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "PoolLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include "CudnnPoolLayer.h"
#endif
namespace paddle {
......@@ -53,7 +53,7 @@ Layer* PoolLayer::create(const LayerConfig& config) {
const std::string& pool = config.inputs(0).pool_conf().pool_type();
if (pool == "max-projection" || pool == "avg-projection") {
return new PoolProjectionLayer(config);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
} else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config);
#endif
......
......@@ -674,7 +674,7 @@ void testLayerGradKernel(TestConfig testConf,
bool useGpu,
bool useWeight,
float epsilon) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) return;
#endif
FLAGS_use_gpu = useGpu;
......
......@@ -119,7 +119,7 @@ TEST(Layer, batchNorm) {
CHECK_EQ(static_cast<int>(convLayer->getOutputValue()->getWidth()), 576);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
void batchNormInference(int n, int c, int h, int w) {
MatrixPtr input = std::make_shared<GpuMatrix>(n, c * h * w);
MatrixPtr cudnnOut = std::make_shared<GpuMatrix>(n, c * h * w);
......
......@@ -117,7 +117,7 @@ MatrixPtr doOneConvTest(size_t imgSize,
}
TEST(Layer, convParaUnified) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
MatrixPtr input, resultCpu, resultGpu;
/// TEST1 for conv ///
......
......@@ -150,7 +150,7 @@ TEST(Layer, detectionOutputLayerFwd) {
useGpu,
result2);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
// GPU case 1.
useGpu = true;
inputLoc = Matrix::create(1, 16, false, useGpu);
......
......@@ -51,7 +51,7 @@ void testEvaluator(TestConfig testConf,
string testEvaluatorName,
size_t batchSize,
bool useGpu) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) return;
#endif
FLAGS_use_gpu = useGpu;
......
......@@ -97,7 +97,7 @@ TEST(Layer, kmaxSeqScoreLayer) {
Matrix::create(subSeqStartPosition.back(), 1, false, false);
std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
mode.push_back(true);
#endif
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include <cudnn.h>
#endif
#include <gtest/gtest.h>
......@@ -258,7 +258,7 @@ void testProjectionConv(size_t groups, bool isDeconv) {
true);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(Projection, conv) {
/// test ConvProjection
testProjectionConv(1, false);
......@@ -422,7 +422,7 @@ TEST(Layer, depthwiseConvLayer) {
// 'depthwise_conv' is a sepecial case of 'exconv' whose
// groups size equals to the input channels size.
testDepthwiseConvLayer("exconv", /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testDepthwiseConvLayer("exconv", /* useGpu= */ true);
#endif
}
......@@ -480,7 +480,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, convLayer) {
testConvLayer("exconv", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testConvLayer("exconv", /* trans= */ false, /* useGpu= */ true);
testConvLayer("cudnn_conv", /* trans= */ false, /* useGpu= */ true);
#endif
......@@ -525,7 +525,7 @@ TEST(Layer, convTransLayer) {
for (auto useGpu : {false, true}) {
testConvTransLayer("exconvt", /* trans= */ false, /* useGpu= */ useGpu);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testConvTransLayer("cudnn_convt", /* trans= */ false, /* useGpu= */ true);
#endif
}
......@@ -638,7 +638,7 @@ TEST(Layer, SelectiveFullyConnectedLayer) {
/* trans= */ false,
/* useGup= */ false,
false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testLayerGrad(config,
"selective_fc",
100,
......@@ -1210,7 +1210,7 @@ void testPoolLayer(const string& poolType, bool trans, bool useGpu) {
testLayerGrad(config, "pool", 100, trans, useGpu);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 3200, 0});
......@@ -1236,7 +1236,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
......@@ -1309,7 +1309,7 @@ void testPool3DLayer(const string& poolType, bool trans, bool useGpu) {
TEST(Layer, Pool3DLayer) {
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ false);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ true);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ true);
#endif
......@@ -1695,7 +1695,7 @@ void testBatchNormLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, BatchNormalizationLayer) {
testBatchNormLayer("batch_norm", false, false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testBatchNormLayer("batch_norm", false, true);
if (hl_get_cudnn_lib_version() >= int(4000)) {
testBatchNormLayer("cudnn_batch_norm", false, true);
......@@ -1744,7 +1744,7 @@ void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, testBatchNorm3DLayer) {
testBatchNorm3DLayer("batch_norm", false, false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testBatchNorm3DLayer("batch_norm", false, true);
if (hl_get_cudnn_lib_version() >= int(4000)) {
testBatchNorm3DLayer("cudnn_batch_norm", false, true);
......@@ -2262,7 +2262,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, test3DConvLayer) {
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
......@@ -2339,7 +2339,7 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
TEST(Layer, test3DDeConvLayer) {
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
......
......@@ -243,7 +243,7 @@ TEST(Compare, concat_slice) {
compareNetwork(config_file_a, config_file_b);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(Compare, img_pool) {
std::string config_file_a = "./gserver/tests/img_pool_a.conf";
std::string config_file_b = "./gserver/tests/img_pool_b.conf";
......
......@@ -151,7 +151,7 @@ TEST(Layer, priorBoxLayerFwd) {
useGpu,
result);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
// reset the input parameters
variance[1] = 0.1;
variance[3] = 0.2;
......
......@@ -485,7 +485,7 @@ TEST(ProtoDataProvider, test) {
// Currently in async mode, useGpu is not supported
continue;
}
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) {
continue;
}
......@@ -525,7 +525,7 @@ TEST(ProtoDataProvider, constant_slots) {
for (int numConstantSlots : {1, 2}) {
for (int useGpu : numTwoArray) {
for (int dataCompression : numTwoArray) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) {
continue;
}
......@@ -708,7 +708,7 @@ TEST(ProtoSequenceDataProvider, test) {
// Currently in async mode, useGpu is not supported
continue;
}
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) {
continue;
}
......
......@@ -37,7 +37,7 @@ TEST(PyDataProvider, py_fill_slots) {
config.clear_files();
std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList";
config.set_files(dataFile);
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
bool useGpu = false;
#else
bool useGpu = true;
......@@ -71,7 +71,7 @@ TEST(PyDataProvider, py_fill_nest_slots) {
std::string dataFile = "gserver/tests/pyDataProvider/pyDataProviderList";
config.set_files(dataFile);
EXPECT_EQ(config.IsInitialized(), true);
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
bool useGpu = false;
#else
bool useGpu = true;
......
......@@ -321,7 +321,7 @@ TEST(Layer, SelectiveFcLayer_train_dense_mul) {
"filelist=gserver/tests/SelectiveFcTest/dense_mul_list";
for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) {
break;
}
......@@ -388,7 +388,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config,
outMatSelfc->getWidth(),
outMatSelfc->getElementCnt()));
cpuOutMatSelfc->copyFrom(*outMatSelfc, HPPL_STREAM_DEFAULT);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
if (useGpu) {
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
}
......@@ -418,7 +418,7 @@ void testSelectiveFcLayerTrainSparseMul(const LayerConfig& config,
MatrixPtr cpuOutMatFc(
new CpuMatrix(outMatFc->getHeight(), outMatFc->getWidth()));
cpuOutMatFc->copyFrom(*outMatFc, HPPL_STREAM_DEFAULT);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
if (useGpu) {
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
}
......@@ -443,7 +443,7 @@ TEST(Layer, SelectiveFcLayer_train_sparse_mul) {
selLayerConfig.set_size(fcLayerWidth);
testSelectiveFcLayerTrainSparseMul(selLayerConfig, false);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testSelectiveFcLayerTrainSparseMul(selLayerConfig, true);
#endif
}
......
......@@ -195,7 +195,7 @@ TEST(Layer, SeqSliceLayer) {
vector<vector<real>> ends;
std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
mode.push_back(true);
#endif
genSeqInfo(seqStartPos, subSeqStartPos);
......
......@@ -199,7 +199,7 @@ TEST(Layer, WarpCTCLayer) {
for (auto batchSize : {1, 10, 32}) {
for (auto normByTimes : {false, true}) {
for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU
#ifndef PADDLE_WITH_CUDA
if (useGpu) continue;
#endif
LOG(INFO) << "layerSize=" << layerSize << " batchSize=" << batchSize
......
......@@ -670,7 +670,7 @@ void GpuMatrix::leftMul(Matrix& a, real scaleAB, real scaleT) {
}
void GpuMatrix::selectRows(Matrix& table, IVector& ids) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
CHECK(dynamic_cast<GpuMatrix*>(&table));
CHECK(table.useGpu());
CHECK(ids.useGpu());
......@@ -694,7 +694,7 @@ void GpuMatrix::selectRows(Matrix& table, IVector& ids) {
}
void GpuMatrix::addToRows(Matrix& table, IVector& ids) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
CHECK(dynamic_cast<GpuMatrix*>(&table));
CHECK(table.useGpu());
CHECK(ids.useGpu());
......@@ -741,7 +741,7 @@ void GpuMatrix::rowMax(Matrix& max) {
}
void GpuMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal";
size_t numSamples = getHeight();
size_t beam = maxVal.getWidth();
......
......@@ -836,7 +836,7 @@ void GpuSparseMatrix::zeroMem() {
}
void GpuSparseMatrix::rowMax(IVector& maxIds, Matrix& maxVal) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
CHECK(maxIds.useGpu() && maxVal.useGpu()) << "Matrix type are not equal";
size_t numSamples = getHeight();
size_t beam = maxVal.getWidth();
......
......@@ -172,7 +172,7 @@ void GpuVectorT<T>::isEqualTo(const VectorT<T>& b, const T& value) {
template <class T>
void GpuVectorT<T>::selectFrom(const VectorT<T>& src, const VectorT<int>& ids) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
hl_vector_select_from<T>(this->getData(),
this->getSize(),
src.getData(),
......@@ -850,7 +850,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src,
size_t size)
: sync_(nullptr) {
CHECK_LE(offset + size, static_cast<size_t>(src.getSize()));
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
SyncedFlag* flag = src.getSync();
if (*flag == DATA_AT_CPU) {
src.copyToGpu(); // will set synchronous data between CPU and GPU
......@@ -861,7 +861,7 @@ CpuGpuVectorT<T>::CpuGpuVectorT(CpuGpuVectorT<T>& src,
auto cMemHandle = (src.getVector(false))->getMemoryHandle();
cpuVectorT_ = std::make_shared<CpuVectorT<T>>(
size, std::dynamic_pointer_cast<CpuMemoryHandle>(cMemHandle), offset);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
auto gMemHandle = (src.getVector(true))->getMemoryHandle();
gpuVectorT_ = std::make_shared<GpuVectorT<T>>(
size, std::dynamic_pointer_cast<GpuMemoryHandle>(gMemHandle), offset);
......
......@@ -68,7 +68,7 @@ void testPoolAllocator() {
TEST(Allocator, Pool) {
testPoolAllocator<CpuAllocator>();
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testPoolAllocator<GpuAllocator>();
#endif
}
......@@ -92,7 +92,7 @@ TEST(MemoryHandle, Cpu) {
EXPECT_EQ(ptr1, ptr2);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(MemoryHandle, Gpu) {
int numGpu = hl_get_device_count();
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
/**
* This test file use autotest::AutoCompare and cmpWithoutArg to compares the
* implementation of CPU and GPU member function in
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include <gtest/gtest.h>
#include "paddle/math/Vector.h"
......
......@@ -94,7 +94,7 @@ void testWrapper(F&& f) {
}
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(ExecViaCpu, test1) {
testWrapper(f);
testWrapper(&f);
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include <gtest/gtest.h>
#include "paddle/math/Matrix.h"
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
/**
* This test file use autotest::AutoCompare and cmpWithArg to compares the
* implementation of CPU and GPU member function in Matrix.cpp.
......
......@@ -47,7 +47,7 @@ struct MatrixPara {
SparseFormat format;
};
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
void test_sparse_matrix_mul(MatrixPara paraA,
MatrixPara paraB,
MatrixPara paraC) {
......@@ -452,7 +452,7 @@ TEST(Matrix, SparseMatrixCSRFormatTrimFrom) {
matB->trimFrom(*mat);
checkSMatrixEqual2(matA, matB);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>(
height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSR, true);
matC->trimFrom(*mat);
......@@ -546,7 +546,7 @@ TEST(Matrix, SparseMatrixCSCFormatTrimFrom) {
matB->trimFrom(*mat);
checkSMatrixEqual2(matA, matB);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
GpuSparseMatrixPtr matC = std::make_shared<GpuSparseMatrix>(
height, trimedWidth, height, FLOAT_VALUE, SPARSE_CSC, true);
matC->trimFrom(*mat);
......
......@@ -270,7 +270,7 @@ TEST(Unary, BaseOp) {
TestUnaryVectorT<CpuIVector, int> testCpuIVector(
testUnaryBaseOpInt<CpuIVector>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpuMatrix(testUnaryBaseOp<GpuMatrix>);
TestUnaryVectorT<GpuVector, real> testGpuVector(testUnaryBaseOp<GpuVector>);
TestUnaryVectorT<GpuIVector, int> testGpuIVector(
......@@ -317,7 +317,7 @@ void testUnayrMathOp(Tensor& A1, Tensor& A2) {
TEST(Unary, MathOp) {
TestUnaryMatrix<CpuMatrix> testCpu(testUnayrMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpu(testUnayrMathOp<GpuMatrix>);
#endif
}
......@@ -374,7 +374,7 @@ void testUnayrCompareOp(Tensor& A1, Tensor& A2) {
TEST(Unary, CompareOp) {
TestUnaryMatrix<CpuMatrix> testCpu(testUnayrCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestUnaryMatrix<GpuMatrix> testGpu(testUnayrCompareOp<GpuMatrix>);
#endif
}
......@@ -536,7 +536,7 @@ void testBinaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, BaseOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryBaseOp<GpuMatrix>);
#endif
}
......@@ -710,7 +710,7 @@ void testBinaryMathOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, MathOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryMathOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryMathOp<GpuMatrix>);
#endif
}
......@@ -810,7 +810,7 @@ void testBinaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B) {
TEST(Binary, CompareOp) {
TestBinaryMatrix<CpuMatrix> testCpu(testBinaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestBinaryMatrix<GpuMatrix> testGpu(testBinaryCompareOp<GpuMatrix>);
#endif
}
......@@ -955,7 +955,7 @@ void testTernaryBaseOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
TEST(Ternary, BaseOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryBaseOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryBaseOp<GpuMatrix>);
#endif
}
......@@ -1058,7 +1058,7 @@ void testTernaryCompareOp(Tensor& A1, Tensor& A2, Tensor& B, Tensor& C) {
TEST(Ternary, CompareOp) {
TestTernaryMatrix<CpuMatrix> testCpu(testTernaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestTernaryMatrix<GpuMatrix> testGpu(testTernaryCompareOp<GpuMatrix>);
#endif
}
......@@ -1086,7 +1086,7 @@ void testQuaternaryAdd(
TEST(Quaternary, BaseOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryAdd<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryAdd<GpuMatrix>);
#endif
}
......@@ -1156,7 +1156,7 @@ void testQuaternaryCompareOp(
TEST(Quaternary, CompareOp) {
TestQuaternaryMatrix<CpuMatrix> testCpu(testQuaternaryCompareOp<CpuMatrix>);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TestQuaternaryMatrix<GpuMatrix> testGpu(testQuaternaryCompareOp<GpuMatrix>);
#endif
}
......@@ -91,7 +91,7 @@ int VectorCheckErr(const VectorPtr& vector1, const VectorPtr& vector2) {
typedef std::function<void(size_t size, bool useGpu)> testMatrixFunc;
void testCase(testMatrixFunc matrixFunc) {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
for (auto useGpu : {false, true}) {
#else
for (auto useGpu : {false}) {
......
......@@ -17,7 +17,7 @@ limitations under the License. */
using namespace paddle; // NOLINT
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(MatrixBatchTransTest, test_batch_matrix_transpose) {
const int nx = 100;
const int ny = 50;
......
......@@ -72,7 +72,7 @@ void testLazyAssign(int height, int width) {
TEST(lazyAssign, CPU) { testMatrixCase(testLazyAssign<CpuMatrix>); }
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TEST(lazyAssign, GPU) { testMatrixCase(testLazyAssign<GpuMatrix>); }
#endif
......@@ -142,6 +142,6 @@ void testSgdUpdate(int height, int width) {
TEST(sgdUpdate, CPU) { testMatrixCase(testSgdUpdate<CpuMatrix>); }
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_GPU
TEST(sgdUpdate, GPU) { testMatrixCase(testSgdUpdate<GpuMatrix>); }
#endif
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
/// This unittest checks GpuMatrix/CpuMatrix get same result, so disable when
/// only cpu version.
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#include <gtest/gtest.h>
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
/// This unittest checks GpuSparseMatrix/CpuSparseMatrix get same result,
// so disable when
/// only cpu version.
......
......@@ -175,7 +175,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
}
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
if (system_allocator_->UseGpu()) {
if ((total_used_ + total_free_) == 0) {
// Compute the maximum allocation size for the first allocation.
......
......@@ -62,7 +62,7 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) {
bool CPUAllocator::UseGpu() const { return false; }
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
void* GPUAllocator::Alloc(size_t& index, size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr
......
......@@ -40,7 +40,7 @@ class CPUAllocator : public SystemAllocator {
virtual bool UseGpu() const;
};
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
class GPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t& index, size_t size);
......
......@@ -56,7 +56,7 @@ TEST(CPUAllocator, LockMem) {
TestAllocator(a, 0);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(GPUAllocator, Alloc) {
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
......
......@@ -26,7 +26,7 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
std::memcpy(dst, src, num);
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
......
......@@ -33,7 +33,7 @@ namespace memory {
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
/**
* \brief Copy memory from one place to another place.
......
......@@ -62,7 +62,7 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
return GetCPUBuddyAllocator()->Used();
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
using BuddyAllocVec = std::vector<BuddyAllocator*>;
......@@ -77,7 +77,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
// GPU buddy allocator initialization
std::call_once(gpu_allocator_flag, [&]() {
int gpu_num = platform::GetDeviceCount();
int gpu_num = platform::GetCUDADeviceCount();
allocators.reserve(gpu_num);
for (int gpu = 0; gpu < gpu_num; gpu++) {
platform::SetDeviceId(gpu);
......
......@@ -80,7 +80,7 @@ TEST(BuddyAllocator, CPUMultAlloc) {
}
}
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
size_t align(size_t size, paddle::platform::GPUPlace place) {
size += sizeof(paddle::memory::detail::Metadata);
......
......@@ -103,12 +103,16 @@ set(DEPS_OPS
recurrent_op
cond_op
cross_entropy_op
softmax_with_cross_entropy_op)
softmax_with_cross_entropy_op
sum_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -69,6 +69,22 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename AttrType>
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddComment(
"LeakyRelu activation operator, "
"leaky_relu = max(x, alpha * x)");
AddAttr<AttrType>("alpha", "The small negative slope")
.SetDefault(static_cast<AttrType>(0.02f));
}
};
class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
......@@ -240,6 +256,9 @@ REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
leaky_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
......
......@@ -309,6 +309,33 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(alpha * x);
}
};
template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = alpha * (x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
......@@ -379,4 +406,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor)
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adadelta_op.h"
namespace paddle {
namespace operators {
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredGrad"),
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredGradOut"),
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AvgSquaredUpdateOut"),
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"param and grad input of AdadeltaOp should have same dimension");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
"Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
"Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
};
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad",
"(Tensor) Input expectation of squared gradient");
AddInput("AvgSquaredUpdate",
"(Tensor) Input expectation of squared parameter updates");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("AvgSquaredGradOut",
"(Tensor) Output expectation of squared gradient");
AddOutput("AvgSquaredUpdateOut",
"(Tensor) Output expectation of squared parameter updates");
AddAttr<float>("rho",
"(float, default 0.95) Exponential decay rate "
"for squared gradients.")
.SetDefault(0.95f);
AddAttr<float>("epsilon",
"(float, default 1.0e-6) Constant for "
"numerical stability")
.SetDefault(1.0e-6f);
AddComment(R"DOC(
Adadelta Updates Operator.
This implements the Adadelta optimizer[1]. Adadelta is a per-dimension
adaptive learning rate method for gradient descent.
Adadelta updates:
avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * grad * grad
param_update = - sqrt((avg_squared_update + epsilon) /
(avg_squared_grad_out + epsilon)) * grad
avg_squared_update_out = rho * avg_squared_update + (1 - rho) * param_update**2
param_out = param + param_update
References:
[1] ADADELTA: An Adaptive Learning Rate Method
https://arxiv.org/abs/1212.5701
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
REGISTER_OP_CPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adadelta_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
auto avg_squared_update_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredUpdateOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
float rho = ctx.Attr<float>("rho");
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
// Squared gradient accumulator
auto avg_squared_grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredGrad"));
// Squared updates accumulator
auto avg_squared_update = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredUpdate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto avg_squared_grad_out =
framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
auto avg_squared_update_out =
framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
avg_squared_grad_out.device(place) =
rho * avg_squared_grad + (1 - rho) * grad.square();
auto update =
-((avg_squared_update + epsilon) / (avg_squared_grad_out + epsilon))
.sqrt() *
grad;
avg_squared_update_out.device(place) =
rho * avg_squared_update + (1 - rho) * update.square();
param_out.device(place) = param + update;
}
};
} // namespace operators
} // namespace paddle
......@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet(
dim[0] = index_tensors[i].dims()[0];
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
}
}
......@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
Variable* var_child = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = &var_child->Get<LoDTensor>();
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
tensor_parent);
}
}
......
......@@ -34,7 +34,7 @@ struct StridedMemcpyFunctor<T, 1> {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
} else {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::GPUPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
......
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Place;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
size_t index_size, size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int gather_i = indices[indices_i];
int params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
/**
* A thin wrapper on gpu tensor
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
T* p_output = output->data<T>();
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
GatherCUDAKernel<T><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
}
} // namespace operators
} // namespace paddle
......@@ -24,49 +24,40 @@ limitations under the License. */
namespace paddle {
namespace operators {
// Implementation of CPU copy
template <typename T>
void CPUGather(const T* src, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
}
}
// Implementation of GPU copy:
template <typename T>
void GPUGather(const T* src, const int* index, const int slice_size,
const int index_size, T* output);
using framework::Tensor;
/**
* A thin wrapper for gathering on cpu tensor
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
auto src_dims = src->dims();
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
T* p_output = output->data<T>();
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering
if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
}
......
......@@ -31,6 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GatherOp should not be null.");
auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1);
int batch_size = ctx->GetInputDim("Index")[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx->GetInputDim("X"));
......@@ -79,8 +81,5 @@ Out = X[Index]
namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
gather_grad,
ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gather.cu.h"
#include "paddle/framework/eigen.h"
#include "paddle/operators/gather_op.h"
#include "scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
GPUGather<T>(ctx.device_context(), *x, *index, output);
}
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *x = ctx.Input<Tensor>("X");
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
......@@ -23,29 +23,40 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
template <typename T>
class GatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index");
auto *Y = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
Y->mutable_data<T>(ctx.GetPlace());
Gather<T>(ctx.GetPlace(), X, Index, Y);
CPUGather<T>(ctx.device_context(), *x, *index, output);
}
};
template <typename Place, typename T>
template <typename T>
class GatherGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
dX->mutable_data<T>(ctx.GetPlace());
ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
}
};
......
......@@ -41,7 +41,9 @@ TEST(Gather, GatherData) {
int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
Gather<int>(CPUPlace(), src, index, output);
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
CPUGather<int>(ctx, *src, *index, output);
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
......
......@@ -71,7 +71,7 @@ void testIm2col() {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
......@@ -116,7 +116,7 @@ void testIm2col() {
TEST(math, im2col) {
testIm2col<paddle::platform::CPUPlace>();
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
testIm2col<paddle::platform::GPUPlace>();
#endif
}
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU
#ifdef PADDLE_WITH_CUDA
TEST(math_function, notrans_mul_trans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
......
......@@ -30,36 +30,39 @@ using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len, 0);
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
CreateScopes(scope, seq_len);
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
InitMemories(step_scopes[0]);
for (size_t step_id = 0; step_id < seq_len; step_id++) {
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
}
void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
void RecurrentAlgorithm::CreateScopes(const Scope& scope,
size_t seq_len) const {
// TODO(superjom) Only two scopes are needed for inference, this case will be
// supported later.
auto step_scopes_var = scope.FindVar(arg_->step_scopes);
auto* step_scopes_var = scope.FindVar(arg_->step_scopes);
PADDLE_ENFORCE(step_scopes_var != nullptr, "");
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
auto* step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
if (seq_len > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
......@@ -82,8 +85,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
void RecurrentAlgorithm::InitMemories(Scope* step_scope,
bool infer_shape_mode) const {
void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
for (auto& attr : arg_->memories) {
auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
......@@ -91,12 +93,9 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
attr.boot_var);
auto* boot_mem =
step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
} else {
pre_mem->ShareDataWith<float>(*boot_mem);
}
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
pre_mem->ShareDataWith<float>(*boot_mem);
}
}
......@@ -146,23 +145,23 @@ class RecurrentAlgorithmProtoAndCheckerMaker
void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
if (step_id != seq_len - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
LinkBootMemoryGradients(step_scopes[0]);
}
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope, bool infer_shape_mode) const {
Scope* step_scope) const {
for (auto& attr : arg_->memories) {
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exists", attr.var);
......@@ -171,11 +170,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
auto* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
boot_mem_grad->Resize(mem_grad->dims());
} else {
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
boot_mem_grad->Resize(mem_grad->dims());
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
}
......
......@@ -48,7 +48,7 @@ class RecurrentAlgorithm {
* NOTE the scopes are reused in both the forward and backward, so just
* create once and expand its size if more steps need.
*/
void CreateScopes(const framework::Scope& scope) const;
void CreateScopes(const framework::Scope& scope, size_t seq_len) const;
const std::vector<framework::Scope*>& GetStepScopes(
const framework::Scope& scope) const {
......@@ -56,12 +56,11 @@ class RecurrentAlgorithm {
->GetMutable<std::vector<framework::Scope*>>();
}
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
void InitMemories(framework::Scope* step_scopes) const;
private:
std::unique_ptr<framework::OperatorBase>* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};
class RecurrentGradientAlgorithm {
......@@ -86,8 +85,7 @@ class RecurrentGradientAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(framework::Scope* step_scopes,
bool infer_shape_mode) const;
void LinkBootMemoryGradients(framework::Scope* step_scopes) const;
protected:
inline const std::vector<framework::Scope*>& GetStepScopes(
......@@ -98,7 +96,6 @@ class RecurrentGradientAlgorithm {
private:
rnn::Argument* arg_;
mutable size_t seq_len_;
std::unique_ptr<framework::OperatorBase>* stepnet_;
};
......@@ -123,6 +120,7 @@ class RecurrentOp : public framework::OperatorBase {
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册