提交 58369d5c 编写于 作者: L Luo Tao

Merge branch 'develop' into op_refine

# Design Doc: Functions, Operators, and Layers
In a DL system, we can compose one or more fine grained operators into a coarse grained one. For example, the FC layer can be composed of a multiplication operator and an add operator.
Historically, some fine grained operations are known as operators, and some coarse level ones are known as layers. But we need a well-defined separation.
In general, operators are those very fine grained operations, e.g., mul and add. In the implementation, we can write them as C++ functions:
```c++
template <typename T> T add(T x, T y) { return x + y; }
template <typename T> T mul(T x, T y) { return x * y; }
```
Then we can wrap them into operators which are C++ classes and can be created from Python bindings by name. A C macro can do this. For example, the following macro invocation
```c++
#define MAKE_FUNCTION_OPERATOR(mul);
```
generates
```c++
template <typename T> class mulOp : public OperatorBase {...};
REGISTER_OP(mulOp<float32>, "mul");
```
so that in Python we can create operator mul by:
```python
X1 = Var()
X2 = Var()
Y = Var()
paddle.cpp.create_operator("mul", input=[X1, X2], output=Y)
```
Also, at the same time, we can compose a coarse level C++ operator class by composing functions `mul` and `add`:
```c++
template <typename T>
class FCOp : public OperatorBase {
public:
void Run(...) {
add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b");
}
};
REGISTER_OP(FCOp, "fc");
```
We need to support such composition in Python as well. To do so, we need a higher level Python wrapping of operator creation than `paddle.cpp.create_operator`. This higher level operator API should be compatible with the layer API.
Let's explain using an example. Suppose that we are going to compose the FC using mul and add in Python, we'd like to have Python functions `mul` and `add` defined in module `operator`:
```python
def operator.mul(X1, X2):
O = Var()
paddle.cpp.create_operator("mul", input={X1, Y1], output=O)
return O
def operator.add(X1, X2):
O = Var()
paddle.cpp.create_operator("add", input={X1, X2], output=O)
return O
```
Above code snippets are automatically generated. Given them, users can define
```python
def layer.fc(X):
W = Var()
b = Var()
return operator.add(operator.mul(X, W), b)
```
If we don't have `operator.mul` and `operator.add`, the definiton of `layer.fc` would be complicated:
```python
def layer.fc(X):
W = Var()
b = Var()
O1 = Var()
paddle.cpp.create_operator("mul", input=[X, W], output=O1)
O2 = Var()
paddle.cpp.create_operator("add", input=[O1, b], output=O2)
return O2
```
We'd like to have Python bindings to operators in package `paddle.operator`, and Python compositions of operators in package `paddle.layer`. So we have the following concepts in above illustrative example:
```
| C++ functions/functors | mul | add | | |
| C++ operator class | mulOp | addOp | FCOp | |
| Python binding | operator.mul | operator.add | operator.fc | |
| Python function | | | | layer.fc |
```
This is how we differentiate layer and operators in PaddlePaddle:
- those defined in C++ and have a lightweighted Python wrapper in module `operators` are operators; whereas
- those who don't have C++ implementations but a Python implementation that compose C++ operators are known as layers.
IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`.
```python
import paddle as pd
x = var()
y = var()
cond = var()
b = pd.create_ifop(inputs=[x], output_num=1)
with b.true_block():
x = b.inputs(0)
z = operator.add(x, y)
b.set_output(0, operator.softmax(z))
out = b(cond)
```
If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N:
```python
import paddle as pd
x = var()
y = var()
cond = var()
default_value = var()
b = pd.create_ifelseop(inputs=[x], output_num=1)
with b.true_block():
x = b.inputs(0)
z = operator.add(x, y)
b.set_output(0, operator.softmax(z))
with b.false_block():
x = b.inputs(0)
z = layer.fc(x)
b.set_output(0, operator.softmax(z))
out = b(cond)
```
If only true_block is set in an IfElseOp, we can have a default value for false as:
```python
import paddle as pd
x = var()
y = var()
cond = var()
default_value = var()
b = pd.create_ifelseop(inputs=[x], output_num=1, default_value)
with b.true_block():
x = b.inputs(0)
z = operator.add(x, y)
b.set_output(0, operator.softmax(z))
out = b(cond)
```
where default_value is a list of vars for `cond` == False.
......@@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel {
```c++
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
```
- `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker``ops::MulOpMaker`注册`ops::MulOpGrad`,类型名为`mul_grad`
- `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker``ops::MulOpMaker`并且注册`ops::MulOpGrad`为其反向Op。
- `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。
- `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace``float`类型,同理,注册`ops::MulKernel`类。
......
......@@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato
For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad);
```
`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
......
......@@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
public:
FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("out", "out");
AddInput("Src", "x");
AddOutput("Dst", "out");
AddComment("");
}
};
......@@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y");
AddOutput("Out", "out");
AddComment("");
}
};
......@@ -148,16 +148,14 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
f::NOP);
REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP);
REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP);
TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp(
......
......@@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
}
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP);
TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
......
......@@ -19,25 +19,24 @@
namespace paddle {
namespace framework {
LODTensor::LOD LODTensor::LOD::SliceLevels(size_t level_begin,
size_t level_end) const {
LOD SliceLevels(const LOD& in, size_t level_begin, size_t level_end) {
LOD new_lod;
new_lod.reserve(level_end - level_begin);
for (size_t i = level_begin; i < level_end; i++) {
new_lod.emplace_back(at(i));
new_lod.emplace_back(in.at(i));
}
return new_lod;
}
LODTensor::LOD LODTensor::LOD::SliceInLevel(size_t level, size_t elem_begin,
size_t elem_end) const {
LOD SliceInLevel(const LOD& in, size_t level, size_t elem_begin,
size_t elem_end) {
// slice the lod.
LOD new_lod;
new_lod.reserve(size() - level);
auto start = this->at(level)[elem_begin];
auto end = this->at(level)[elem_end];
new_lod.reserve(in.size() - level);
auto start = in.at(level)[elem_begin];
auto end = in.at(level)[elem_end];
for (auto it = this->begin() + level; it != this->end(); it++) {
for (auto it = in.begin() + level; it != in.end(); it++) {
auto it_begin = std::find(it->begin(), it->end(), start);
auto it_end = std::find(it_begin, it->end(), end);
PADDLE_ENFORCE(it_begin != it->end(), "error in parsing lod info");
......@@ -49,11 +48,11 @@ LODTensor::LOD LODTensor::LOD::SliceInLevel(size_t level, size_t elem_begin,
[start](int v) { return v - start; });
PADDLE_ENFORCE_EQ(new_lod.back().front(), 0, "error in slice LOD");
}
PADDLE_ENFORCE_LE(new_lod.size(), this->size());
PADDLE_ENFORCE_LE(new_lod.size(), in.size());
return new_lod;
}
bool operator==(const LODTensor::LOD& a, const LODTensor::LOD& b) {
bool operator==(const LOD& a, const LOD& b) {
if (a.size() != b.size()) {
return false;
}
......@@ -70,9 +69,27 @@ bool operator==(const LODTensor::LOD& a, const LODTensor::LOD& b) {
}
}
}
return true;
}
void LODTensor::SliceLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod;
}
void LODTensor::SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod;
}
} // namespace framework
} // namespace paddle
......@@ -15,7 +15,7 @@
#pragma once
#include <memory>
#if !defined(PADDLE_ONLY_CPU)
#ifndef PADDLE_ONLY_CPU
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
......@@ -27,33 +27,39 @@
namespace paddle {
namespace framework {
#ifdef PADDLE_ONLY_CPU
template <typename T>
using Vector = std::vector<T>;
#else
template <typename T>
using Vector = thrust::host_vector<T>;
#endif
using LOD = std::vector<Vector<size_t>>;
LOD SliceLevels(const LOD& in, size_t level_begin, size_t level_end);
LOD SliceInLevel(const LOD& in, size_t level, size_t elem_begin,
size_t elem_end);
bool operator==(const LOD& a, const LOD& b);
/*
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
class LODTensor : public Tensor {
class LODTensor {
public:
// Level save offsets of each unit.
#ifdef PADDLE_ONLY_CPU
template <typename T>
using Vector = std::vector<T>;
#else
template <typename T>
using Vector = thrust::host_vector<T>;
#endif
// LoD stores offsets of each level of units, the largest units level first,
// then the smaller units level. Each Level stores the offsets of units in
// Tesor.
class LOD : public std::vector<Vector<size_t>> {
public:
LOD SliceLevels(size_t level_begin, size_t level_end) const;
LOD SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) const;
};
LODTensor() {}
explicit LODTensor(const LOD &lod) : lod_(lod) {}
LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
void set_lod(const LOD& lod) { lod_ = lod; }
virtual Tensor *Clone() const { return new LODTensor(lod_); }
void set_tensor(Tensor* tensor) { tensor_ = tensor; }
Tensor& tensor() { return *tensor_; }
LOD lod() { return lod_; }
/*
* Get a element from LOD.
......@@ -79,71 +85,23 @@ class LODTensor : public Tensor {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
// the last offset is the end of last element
return lod_[level].size() - 1;
return (lod_)[level].size() - 1;
}
/*
* Slice of levels[level_begin:level_end], with tensor shared.
* Slice of levels[level_begin:level_end]
*/
template <typename T>
LODTensor SliceLevels(size_t level_begin, size_t level_end) const;
void SliceLevels(size_t level_begin, size_t level_end);
/*
* Slice of elements of a level, [elem_begin: elem_end], with tensor shared.
* Slice of elements of a level, [elem_begin: elem_end]
* @note: low performance in slice lod_.
*/
template <typename T>
LODTensor SliceInLevel(size_t level, size_t elem_begin,
size_t elem_end) const;
/*
* Copy other's lod_'s content, free to mutate.
*/
void CopyLOD(const LODTensor &other) { lod_ = other.lod_; }
/*
* Determine whether LODTensor has a valid LOD info.
*/
const LOD &lod() const { return lod_; }
LOD *mutable_lod() { return &lod_; }
virtual ~LODTensor() {}
void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end);
private:
LOD lod_;
Tensor* tensor_; // not owned
};
bool operator==(const LODTensor::LOD &a, const LODTensor::LOD &b);
template <typename T>
LODTensor LODTensor::SliceLevels(size_t level_begin, size_t level_end) const {
auto new_lod = lod_.SliceLevels(level_begin, level_end);
// slice levels just need to update LOD info, each level will contains the
// whole tensor_, so no need to modify tensor_.
LODTensor new_tensor(new_lod);
new_tensor.ShareDataWith<T>(*this);
return new_tensor;
}
template <typename T>
LODTensor LODTensor::SliceInLevel(size_t level, size_t elem_begin,
size_t elem_end) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem_begin < NumElements(level),
"element begin [%d] out of range [%d]", elem_begin,
NumElements(level));
PADDLE_ENFORCE(elem_end < NumElements(level) + 1,
"element end [%d] out of range [%d]", elem_end,
NumElements(level));
auto new_lod = lod_.SliceInLevel(level, elem_begin, elem_end);
// slice elements just need to update LOD info, because offsets are not
// changed, so the original tensor_ can be reused.
LODTensor new_tensor(new_lod);
new_tensor.ShareDataWith<T>(*this);
return new_tensor;
}
} // namespace framework
} // namespace paddle
# Design Doc: LoD (Level-of-Detail) Tensor
PaddlePaddle's RNN doesn't require that all instances have the same length. To do so, we introduce an extension to Tensor, namely, LoD Tensor.
## Challenge of Variable-length Inputs
People usually represent a mini-batch by a Tensor. For example, a mini-batch of 32 images, each of size 32x32, is a 10x32x32 Tensor. So a transformation, T, of all images can be a matrix multiplication of the 32x32xO-dimensional tensor T and the 10x32x32 Tensor.
Another example is that each mini-batch contains 32 sentences, where each word is a D-dimensional one-hot vector. If all sentences have the same length L, we can represent this mini-batch by a 32xLxD tensor. However, in most cases, sentences have variable lengths, and we will need an index data structure to record these variable lengths.
## LoD as a Solution
### Mini-Batch of variable-length sentenses
Let's imagine a mini-batch of 3 variable lengths sentences, containing 3, 1, and 2 words respectively. We can represent it by a (3+1+2)xD tensor plus some index information:
```
3
3 1 2
||| | ||
```
Each `|` represents a D-dimensional word vectors. The number 3 on top indicate 3 sentences, and numbers 3, 1, and 2 on the second level represent the number of words in each sentence.
### Mini-Batch of variable-length videos
This approach generalizes to the case where elements are not words, but higher dimensional objects, like images. Suppose that a mini-batch contains videos of the same frame size 640x480. If a mini-batch contains 3 videos of 3, 1, and 2 frames respectively. The underlying tensor is of size (3+1+2)x640x480. The index information illustrates as:
```
3
3 1 2
口口口 口 口口
```
where each `口` represents an image.
### Mini-Batch of fixed-size images
Let's get back to a typical example, image classification, where each mini-batch has M fixed-sized images. The LoD Tensor representation is
```
M
1 1 1 1 1
口口口口 ... 口
```
The many 1's on the second level seem duplicated. For this particular case of 2 levels and the second level always have length 1, we can ignore the LoD index.
### Design and summarization
In summary, as long as that the essential elements (words or images) have the same size, we can represent mini-batches by a LoD Tensor:
- The underlying tensor has size LxD1xD2x..., where D1xD2... is the size of the essential elements, and
- the first dimension size L has an additon property -- a LoD index as a nested vector:
```c++
typedef std::vector<std::vector> > LoD;
```
- The LoD index can is not necessary when there are only two levels and all elements of the second level have length 1.
## Slicing of LoD Tensor
Consider that we have a network with three levels of RNN: the top level one handles articles, the second level one handles sentences, and the basic level one handles words. This network requires that mini-batches represented by 4 level LoD Tensor, for example,
```
3
3 1 2
3 2 4 1 2 3
||| || |||| | || |||
```
To allow each level of RNN to handle its input, we define **the slicing of a LoD Tensor is defined as getting the j-th sequence on level i, or the <i,j>-slice**
For example, the <2,1>-slice of above slice is
```
2
||
```
and the <1,2>-slice of above example is
```
2
2 3
|| |||
```
Let's go on slicing this slice. Its <1,1>-slice is
```
3
|||
```
### The General Slicing Algorithm
The algorithm, with over-simplified data structure, is defined as
```c++
typedef vector<vector<int> > LoD;
struct LoDTensor {
LoD lod_;
float* tensor_;
};
LoDTensor Slice(const LoDTensor& lodt, int level, int sequence) {
}
```
### Slicing the Top Level
Please be aware that an RNN operator only slices the top level of a LoD Tensor to get the step inputs.
```c++
LoDTensor Slice(const LoDTensor& lodt, int sequence) {
}
```
......@@ -24,13 +24,12 @@ namespace framework {
class LODTensorTester : public ::testing::Test {
public:
virtual void SetUp() override {
lod_tensor.reset(new LODTensor);
// tensor's batch_size: 30
// 3 levels
// 0 10 20
// 0 5 10 15 20
// 0 2 5 7 10 12 15 20
LODTensor::LOD lod;
LOD lod;
lod.push_back(std::vector<size_t>{0, 10, 20});
lod.push_back(std::vector<size_t>{0, 5, 10, 15, 20});
lod.push_back(std::vector<size_t>{0, 2, 5, 7, 10, 12, 15, 17, 20});
......@@ -41,75 +40,65 @@ class LODTensorTester : public ::testing::Test {
// malloc memory
tensor.mutable_data<float>(place);
lod_tensor.reset(new LODTensor(lod));
lod_tensor->Resize({20 /*batch size*/, 128 /*dim*/});
lod_tensor->ShareDataWith<float>(tensor);
// lod_tensor->ShareDataWith<Tensor>(tensor);
lod_tensor.set_lod(lod);
lod_tensor.set_tensor(&tensor);
}
protected:
std::unique_ptr<LODTensor> lod_tensor;
platform::CPUPlace place;
Tensor tensor;
LODTensor lod_tensor;
};
TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor->NumLevels(), 3UL); }
TEST_F(LODTensorTester, NumLevels) { ASSERT_EQ(lod_tensor.NumLevels(), 3UL); }
TEST_F(LODTensorTester, NumElements) {
ASSERT_EQ(lod_tensor->NumElements(0), 2UL);
ASSERT_EQ(lod_tensor->NumElements(1), 4UL);
ASSERT_EQ(lod_tensor->NumElements(2), 8UL);
ASSERT_EQ(lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(lod_tensor.NumElements(2), 8UL);
}
TEST_F(LODTensorTester, SliceLevels) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceLevels<float>(level, level + 1);
LODTensor new_lod_tensor = lod_tensor;
new_lod_tensor.SliceLevels(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0UL), lod_tensor->NumElements(level));
// ASSERT_EQ(new_lod_tensor, *lod_tensor);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
auto new_lod_tensor = lod_tensor->SliceLevels<float>(level, level + 2);
LODTensor new_lod_tensor = lod_tensor;
new_lod_tensor.SliceLevels(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor->NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
lod_tensor->NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor->data<float>());
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
}
TEST_F(LODTensorTester, SliceInLevel) {
size_t level = 0;
auto new_lod_tensor = lod_tensor->SliceInLevel<float>(level, 0, 2);
LODTensor new_lod_tensor = lod_tensor;
new_lod_tensor.SliceInLevel(level, 0, 2);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor->data<float>());
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
level = 1;
new_lod_tensor = lod_tensor->SliceInLevel<float>(level, 0, 2);
new_lod_tensor = lod_tensor;
new_lod_tensor.SliceInLevel(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor->data<float>());
}
TEST_F(LODTensorTester, ShareLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.CopyLOD(*lod_tensor);
ASSERT_EQ(new_lod_tensor.lod(), lod_tensor->lod());
}
TEST_F(LODTensorTester, CopyLOD) {
LODTensor new_lod_tensor;
new_lod_tensor.CopyLOD(*lod_tensor);
bool equals = std::equal(lod_tensor->lod().begin(), lod_tensor->lod().end(),
new_lod_tensor.lod().begin());
ASSERT_TRUE(equals);
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
} // namespace framework
......
......@@ -80,9 +80,19 @@ class OpInfoMap {
}
const OpInfo& Get(const std::string& type) const {
auto op_info_ptr = GetNullable(type);
PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered",
type);
return *op_info_ptr;
}
const OpInfo* GetNullable(const std::string& type) const {
auto it = map_.find(type);
PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type);
return it->second;
if (it == map_.end()) {
return nullptr;
} else {
return &it->second;
}
}
template <typename Callback>
......
......@@ -33,8 +33,7 @@ namespace framework {
class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
static void RegisterOp(const std::string& op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
OpInfo op_info;
......@@ -43,9 +42,9 @@ class OpRegistry {
const VariableNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.grad_op_type_ = op_type + "_grad";
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
......@@ -55,15 +54,14 @@ class OpRegistry {
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
// register gradient op
RegisterOp<GradOpType, NOPMaker, NOP>(op_info.grad_op_type_);
} else {
op_info.grad_op_type_ = "";
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
}
OpInfoMap::Instance().Insert(op_type, op_info);
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
}
}
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
......@@ -92,10 +90,8 @@ class Registrar {
template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar {
public:
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
grad_op_type);
explicit OpRegistrar(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type);
}
};
......@@ -121,8 +117,7 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to register Operator.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \
......@@ -137,14 +132,14 @@ class OpKernelRegistrar : public Registrar {
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP)
/**
* Macro to register OperatorKernel.
......
......@@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif
const std::string& OperatorBase::Input(const std::string& name) const {
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL,
PADDLE_ENFORCE_LE(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
name);
return ins[0];
return ins.empty() ? kEmptyVarName : ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
......@@ -49,12 +49,12 @@ const std::vector<std::string>& OperatorBase::Inputs(
return it->second;
}
const std::string& OperatorBase::Output(const std::string& name) const {
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
PADDLE_ENFORCE_EQ(outs.size(), 1UL,
PADDLE_ENFORCE_LE(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_,
name);
return outs[0];
return outs.empty() ? kEmptyVarName : outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(
......@@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
GenerateTemporaryNames();
CheckAllInputOutputSet();
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
......@@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
return ret_val;
}
void OperatorBase::CheckAllInputOutputSet() const {
auto& info_map = OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(Type());
if (op_info == nullptr || op_info->proto_ == nullptr) return;
for (auto& in : op_info->Proto().inputs()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"Type %s's input %s is not set", Type(), in.name());
}
for (auto& out : op_info->Proto().outputs()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"Type %s's output %s is not set", Type(), out.name());
}
}
void OperatorBase::GenerateTemporaryNames() {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
......
......@@ -95,12 +95,12 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const;
......@@ -127,6 +127,10 @@ class OperatorBase {
// IG (Inputs Gradients)
VariableNameMap outputs_;
AttributeMap attrs_;
private:
void GenerateTemporaryNames();
void CheckAllInputOutputSet() const;
};
// Macro for define a clone method.
......@@ -229,6 +233,15 @@ class InferShapeContext {
InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
const OperatorBase& op() const { return op_; }
const Scope& scope() const { return scope_; }
template <typename T>
inline const T& GetAttr(const std::string& name) const {
return op_.GetAttr<T>(name);
}
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
......@@ -238,11 +251,13 @@ class InferShapeContext {
}
const Variable* InputVar(const std::string& name) const {
return scope_.FindVar(op_.Input(name));
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* OutputVar(const std::string& name) const {
return scope_.FindVar(op_.Output(name));
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
const std::vector<const Variable*> MultiInputVar(
......@@ -250,9 +265,11 @@ class InferShapeContext {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_.FindVar(name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}
......@@ -260,24 +277,24 @@ class InferShapeContext {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_.FindVar(name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}
template <typename T>
const T* Input(const std::string& name) const {
auto* var = InputVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
return var == nullptr ? nullptr : &var->Get<T>();
}
template <typename T>
T* Output(const std::string& name) const {
auto var = OutputVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
return var == nullptr ? nullptr : var->GetMutable<T>();
}
template <typename T>
......@@ -288,10 +305,7 @@ class InferShapeContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiInput(%s:%s) should not be nullptr", name,
sub_name);
return &var->Get<T>();
return var == nullptr ? nullptr : &var->Get<T>();
});
return res;
}
......@@ -304,14 +318,12 @@ class InferShapeContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr.", name,
sub_name);
return var->GetMutable<T>();
return var == nullptr ? nullptr : var->GetMutable<T>();
});
return res;
}
private:
const OperatorBase& op_;
const Scope& scope_;
};
......
......@@ -122,10 +122,10 @@ class CPUKernelTest : public OpKernel {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl;
std::cout << ctx.op().DebugString() << std::endl;
cpu_kernel_run_num++;
ASSERT_EQ(ctx.op_.Input("x"), "IN1");
ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
ASSERT_EQ(ctx.op().Input("x"), "IN1");
ASSERT_EQ(ctx.op().Output("y"), "OUT1");
}
};
......@@ -148,7 +148,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
class CPUKernalMultiInputsTest : public OpKernel {
public:
void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs");
auto xs = ctx.op().Inputs("xs");
ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1");
......@@ -172,10 +172,10 @@ class CPUKernalMultiInputsTest : public OpKernel {
auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2U);
auto k = ctx.op_.Input("k");
auto k = ctx.op().Input("k");
ASSERT_EQ(k, "k0");
auto ys = ctx.op_.Outputs("ys");
auto ys = ctx.op().Outputs("ys");
ASSERT_EQ(ys.size(), 2UL);
ASSERT_EQ(ys[0], "y0");
ASSERT_EQ(ys[1], "y1");
......
......@@ -48,7 +48,16 @@ public:
<< inputLayers_.size() << ") at " << getName();
}
s << format.substr(pos);
LOG(INFO) << s.str();
const std::string delimiter("\n");
std::string content = s.str();
std::string::size_type foundPos = 0;
std::string::size_type prevPos = 0;
while ((foundPos = content.find(delimiter, prevPos)) != std::string::npos) {
LOG(INFO) << content.substr(prevPos, foundPos - prevPos);
prevPos = foundPos + delimiter.size();
}
LOG(INFO) << content.substr(prevPos);
}
void backward(const UpdateCallback& callback) override {}
......
......@@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>);
......@@ -67,8 +67,7 @@ OnehotCrossEntropy Operator.
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
......
......@@ -63,8 +63,7 @@ Out = X[Index]
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -19,13 +19,12 @@ template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
float mean = context.GetAttr<float>("mean");
float std = context.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
......
......@@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
T std = static_cast<T>(context.op_.GetAttr<float>("std"));
T mean = static_cast<T>(context.GetAttr<float>("mean"));
T std = static_cast<T>(context.GetAttr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
......@@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
lookup_table_grad, ops::LookupTableOpGrad);
ops::LookupTableOpGrad);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
......@@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad,
......
......@@ -81,7 +81,6 @@ class MinusGradOp : public NetOp {
USE_OP(scale);
USE_OP_ITSELF(identity);
namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
ops::MinusGradOp<float>);
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp<float>);
REGISTER_OP_CPU_KERNEL(minus,
ops::MinusKernel<paddle::platform::CPUPlace, float>);
......@@ -29,10 +29,10 @@ class MulOp : public framework::OperatorWithKernel {
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X"));
ctx.op().Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y"));
ctx.op().Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height.");
......@@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
......@@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
rowwise_add_grad, ops::RowwiseAddGradOp);
ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -97,7 +97,7 @@ class IdentityOp : public NetOp {
namespace ops = paddle::operators;
REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
ops::ScaleGradOp<float>);
REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>);
......
......@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.op_.GetAttr<AttrType>("scale"));
auto scale = static_cast<T>(context.GetAttr<AttrType>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
......@@ -77,8 +77,7 @@ Out[Index] = Ref[Index] + Updates
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
ops::ScatterGradOp);
REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp);
REGISTER_OP_CPU_KERNEL(scatter,
ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op_.GetAttr<float>("learning_rate");
float lr = ctx.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -53,8 +53,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::SigmoidOpGrad);
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -62,8 +62,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
ops::SoftmaxOpGrad);
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max")));
static_cast<T>(context.GetAttr<float>("min")),
static_cast<T>(context.GetAttr<float>("max")));
ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine);
......
......@@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
T min = static_cast<T>(context.GetAttr<float>("min"));
T max = static_cast<T>(context.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册