提交 1cd20140 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/pybind_for_protobuf_desc

......@@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs):
def process(settings, file_list):
for i in xrange(1024):
img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten()
lab = random.randint(0, settings.num_class)
lab = random.randint(0, settings.num_class - 1)
yield img.astype('float32'), int(lab)
set -e
unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"
function train() {
topology=$1
bs=$2
use_mkldnn=$3
if [ $3 == "True" ]; then
use_mkldnn=$3
thread=1
log="logs/${topology}-mkldnn-${bs}.log"
elif [ $3 == "False" ]; then
use_mkldnn=$3
thread=`nproc`
log="logs/${topology}-${thread}mklml-${bs}.log"
else
echo "Wrong input $3, use True or False."
fi
args="batch_size=${bs}"
config="${topology}.py"
paddle train --job=time \
--config=$config \
--use_mkldnn=$use_mkldnn \
--use_gpu=False \
--trainer_count=$thread \
--log_period=10 \
--test_period=100 \
--config_args=$args \
2>&1 | tee ${log}
}
if [ ! -d "train.list" ]; then
echo " " > train.list
fi
if [ ! -d "logs" ]; then
mkdir logs
fi
#========= mkldnn =========#
# vgg
train vgg 64 True
train vgg 128 True
train vgg 256 True
#========== mklml ===========#
train vgg 64 False
train vgg 128 False
train vgg 256 False
#!/usr/bin/env python
from paddle.trainer_config_helpers import *
height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg('layer_num', int, 19)
args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
"train.list", None, module="provider", obj="process", args=args)
settings(
batch_size=batch_size,
learning_rate=0.01 / batch_size,
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * batch_size))
img = data_layer(name='image', size=height * width * 3)
def vgg_network(vgg_num=3):
tmp = img_conv_group(
input=img,
num_channels=3,
conv_padding=1,
conv_num_filter=[64, 64],
conv_filter_size=3,
conv_act=ReluActivation(),
pool_size=2,
pool_stride=2,
pool_type=MaxPooling())
tmp = img_conv_group(
input=tmp,
conv_num_filter=[128, 128],
conv_padding=1,
conv_filter_size=3,
conv_act=ReluActivation(),
pool_stride=2,
pool_type=MaxPooling(),
pool_size=2)
channels = []
for i in range(vgg_num):
channels.append(256)
tmp = img_conv_group(
input=tmp,
conv_num_filter=channels,
conv_padding=1,
conv_filter_size=3,
conv_act=ReluActivation(),
pool_stride=2,
pool_type=MaxPooling(),
pool_size=2)
channels = []
for i in range(vgg_num):
channels.append(512)
tmp = img_conv_group(
input=tmp,
conv_num_filter=channels,
conv_padding=1,
conv_filter_size=3,
conv_act=ReluActivation(),
pool_stride=2,
pool_type=MaxPooling(),
pool_size=2)
tmp = img_conv_group(
input=tmp,
conv_num_filter=channels,
conv_padding=1,
conv_filter_size=3,
conv_act=ReluActivation(),
pool_stride=2,
pool_type=MaxPooling(),
pool_size=2)
tmp = fc_layer(
input=tmp,
size=4096,
act=ReluActivation(),
layer_attr=ExtraAttr(drop_rate=0.5))
tmp = fc_layer(
input=tmp,
size=4096,
act=ReluActivation(),
layer_attr=ExtraAttr(drop_rate=0.5))
return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())
if layer_num == 16:
vgg = vgg_network(3)
elif layer_num == 19:
vgg = vgg_network(4)
else:
print("Wrong layer number.")
lab = data_layer('label', num_class)
loss = cross_entropy(input=vgg, label=lab)
outputs(loss)
......@@ -253,7 +253,7 @@ function(nv_library TARGET_NAME)
foreach(source_file ${nv_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS})
......
......@@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME)
target_link_libraries(${TARGET_NAME} log)
endif(ANDROID)
if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR)
target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
endif()
add_dependencies(${TARGET_NAME} ${external_project_dependencies})
endfunction()
......
......@@ -390,4 +390,125 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
* 如果发现最早的报错就是网络通信的问题,很有可能是非独占方式执行导致的端口冲突,可以联系OP,看当前MPI集群是否支持resource=full参数提交,如果支持增加此参数提交,并更换job 端口。
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
\ No newline at end of file
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
19. PaddlePaddle如何输出多个层
------------------------------
* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下:
.. code-block:: python
inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters)
* 指定要输出的字段进行输出。以输出 :code:`value` 字段为例,代码如下:
.. code-block:: python
out = inferer.infer(input=data_batch, flatten_result=False, field=["value"])
这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。
20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用
-------------------------------------------------------------
* :code:`paddle.layer.memory` 用于获取特定layer上一时间步的输出,该layer是通过参数 :code:`name` 指定,即,:code:`paddle.layer.memory` 会关联参数 :code:`name` 取值相同的layer,并将该layer上一时间步的输出作为自身当前时间步的输出。
* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。
21. dropout 使用
-----------------
* 在PaddlePaddle中使用dropout有两种方式
* 在相应layer的 :code:`layer_atter` 设置 :code:`drop_rate`,以 :code:`paddle.layer.fc` 为例,代码如下:
.. code-block:: python
fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5))
* 使用 :code:`paddle.layer.dropout`,以 :code:`paddle.layer.fc` 为例,代码如下:
.. code-block:: python
fc = paddle.layer.fc(input=input)
drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5)
* :code:`paddle.layer.dropout` 实际上使用了 :code:`paddle.layer.add_to`,并在该layer里采用第一种方式设置 :code:`drop_rate` 来使用dropout的。这种方式对内存消耗较大。
* PaddlePaddle在激活函数里实现dropout,而不是在layer里实现。
* :code:`paddle.layer.lstmemory`、:code:`paddle.layer.grumemory`、:code:`paddle.layer.recurrent` 不是通过一般的方式来实现对输出的激活,所以不能采用第一种方式在这几个layer里设置 :code:`drop_rate` 来使用dropout。若要对这几个layer使用dropout,可采用第二种方式,即使用 :code:`paddle.layer.dropout`。
22. 如何设置学习率退火(learning rate annealing)
------------------------------------------------
在相应的优化算法里设置learning_rate_schedule及相关参数,以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_decay_a=0.5,
learning_rate_decay_b=0.75,
learning_rate_schedule="poly",)
PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedule及其对应学习率计算方式如下:
* "constant"
lr = learning_rate
* "poly"
lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b)
其中,num_samples_processed为已训练样本数,下同。
* "caffe_poly"
lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b)
* "exp"
lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b)
* "discexp"
lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b))
* "linear"
lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b)
* "manual"
这是一种按已训练样本数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_schedule="manual",
learning_rate_args="1000:1.0,2000:0.9,3000:0.8",)
在该示例中,当已训练样本数小于等于1000时,学习率为 :code:`1e-3 * 1.0`;当已训练样本数大于1000小于等于2000时,学习率为 :code:`1e-3 * 0.9`;当已训练样本数大于2000时,学习率为 :code:`1e-3 * 0.8`。
* "pass_manual"
这是一种按已训练pass数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_schedule="manual",
learning_rate_args="1:1.0,2:0.9,3:0.8",)
在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。
23. 出现 :code:`Duplicated layer name` 错误怎么办
--------------------------------------------------
出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。
# How to write a new operator
- [Background](#Background)
- [Implementing C++ Types](#Implementing_C++_Types)
- [Defining ProtoMaker](#Defining_ProtoMaker)
- [Defining Operator](#Defining_Operator)
- [Registering Operator](#Registering_Operator)
- [Compilation](#Compilation)
- [Python Binding](#Python_Binding)
- [Unit Tests](#Unit_Tests)
## Background
Here are the base types needed. For details, please refer to the design docs.
- `framework::OperatorBase`: Operator (Op)base class.
- `framework::OpKernel`: Base class for Op computation.
- `framework::OperatorWithKernel`: Inherited from OperatorBase, describing an operator with computation.
- `class OpProtoAndCheckerMaker`: Describes an Operator's input, output, attributes and description, mainly used to interface with Python API.
An operator can be differentiated by whether in has kernel methods. An operator with kernel inherits from `OperatorWithKernel` while the ones without inherit from `OperatorBase`. This tutorial focuses on implementing operators with kernels. In short, an operator includes the following information:
Information | Where is it defined
-------------- | :----------------------
OpProtoMake definition | `.cc`files, Backward Op does not need an OpProtoMake interface.
Op definition | `.cc` files
Kernel implementation | The kernel methods shared between CPU and GPU are defined in `.h` files. CPU-specific kernels live in `.cc` files, while GPU-specific kernels are implemented in `.cu`files.
Registering the Op | Ops are registered in `.cc` files; For Kernel registration, `.cc` files contain the CPU implementation, while `.cu` files contain the GPU implementation.
New Operator implementations are added to the list [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names in the format `*_op.h` (if applicable), `*_op.cc`, `*_op.cu` (if applicable).** The system will use the naming scheme to automatically build operators and their corresponding Python extensions. **
Let's take matrix multiplication operator, [MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc), as an example to introduce the writing of an Operator with Kernel.
## Implementing C++ Types
### 1. Defining Class ProtoMaker
Matrix Multiplication can be written as $Out = X * Y$, meaning that the operation consists of two inputs and pne output.
First, define `ProtoMaker` to describe the Operator's input, output, and additional comments:
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
AddOutput("Out", "(Tensor), 2D tensor of size (M x N)");
AddComment(R"DOC(
Two Element Mul Operator.
The equation is: Out = X * Y
)DOC");
}
};
```
[`MulOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L43)is inherited from`framework::OpProtoAndCheckerMaker`, consisting of 2 variables in the constructor:
- `framework::OpProto` stores Operator input and variable attribute, used for generating Python API interfaces.
- `framework::OpAttrChecker` is used to validate variable attributes.
The constructor utilizes `AddInput`, `AddOutput`, and `AddComment`, so that the corresponding information will be added to `OpProto`.
The code above adds two inputs `X` and `Y` to `MulOp`, an output `Out`, and their corresponding descriptions, in accordance to Paddle's [naming convention](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md).
An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37) is implemented as follows:
```cpp
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
AddComment(R"DOC(Scale operator
The equation is: Out = scale*X
)DOC");
AddAttr<AttrType>("scale", "scale of scale operator.").SetDefault(1.0);
}
};
```
There are two changes in this example:
- `AddInput("X","...").NotInGradient()` expresses that input `X` is not involved in `ScaleOp`'s corresponding computation. If an input to an operator is not participating in back-propagation, please explicitly set `.NotInGradient()`.
- `AddAttr<AttrType>("scale", "...").SetDefault(1.0);` adds `scale`constant as an attribute, and sets the default value to 1.0.
### 2. Defining Operator
The following code defines the interface for MulOp:
```cpp
class MulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
}
};
```
[`MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L22) is inherited from `OperatorWithKernel`. Its `public` member
```cpp
using framework::OperatorWithKernel::OperatorWithKernel;
```
expresses an operator constructor using base class `OperatorWithKernel`, alternatively written as
```cpp
MulOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
```
`InferShape` interface needs to be re-written.`InferShape` is a constant method and cannot modify Op's member variables, its constant member `const framework::InferShapeContext &ctx` can be used to extract input, output, and attributes. It functions to
- 1). validate and error out early: it checks input data dimensions and types.
- 2). configures the tensor shape in the output.
Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later.
### 3. Defining OpKernel
`MulKernel` inherits `framework::OpKernel`, which includes the following templates:
- `typename Place` denotes device type. When different devices, namely the CPU and the GPU, share the same kernel, this template needs to be added. If they don't share kernels, this must not be added. An example of a non-sharing kernel is [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
- `typename T` denotes data type, such as `float` or `double`.
`MulKernel` types need to rewrite the interface for `Compute`.
- `Compute` takes one input variable `const framework::ExecutionContext& context`.
- Compared with `InferShapeContext`, `ExecutionContext` includes device types, and can similarly extract input, output, and attribute variables.
- `Compute` implements the computation logics of an `OpKernel`.
`MulKernel`'s implementation of `Compute` is as follows:
```cpp
template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
}
};
```
Note that **different devices (CPU, GPU)share an Op definition; whether or not they share the same `OpKernel` depends on whether `Compute` calls functions that support both devices.**
`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md).
This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file.
The definition of its corresponding backward operator, if applicable, is similar to that of an forward operator. **Note that a backward operator does not include a `ProtoMaker`**.
### 4. Registering Operator
- In `.cc` files, register forward and backward operator classes and the CPU kernel.
```cpp
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
```
In that code block,
- `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`.
- `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient.
- `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`.
- Registering GPU Kernel in `.cu` files
- Note that if GPU Kernel is implemented using the `Eigen unsupported` module, then on top of `.cu`, a macro definition `#define EIGEN_USE_GPU` is needed, such as
```cpp
// if use Eigen unsupported module before include head files
#define EIGEN_USE_GPU
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::GPUPlace, float>);
```
### 5. Compilation
Run the following commands to compile.
```
make mul_op
```
## Python Binding
The system will automatically bind to Python and link it to a generated library.
## Unit Tests
Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py).
# Cluster bootstrapping tool survey
## Abstract
In order to bring up a cluster from bare metal machine to a fully functional kubernetes cluster for Paddlepaddle to run, we need to utilize some tools. Here we are going to compare [Sextant](https://github.com/k8sp/sextant) and [Tectonic installer](https://github.com/coreos/tectonic-installer)
## Basic assumptions
Here are some basic assumptions before we move on to details
1. You are an administrator of a bare metal machine cluster, which means:
* you have full control to each of the machines.
* you have full control to the network which machines are connected to.
2. Machines can be booted from network with PEX or iPXE
3. You understand the [general procedure to bring up a cluster](#appendix-general-procedure-to-bring-up-a-cluster)
if your cluster is able to mark above items with checkmarks, then keep reading.
## Comparing Sextant and Tectonic installer
### Sextant
Sextant is an end2end solution to bring up a bare metal cluster to a fully functional k8s cluster, it integrates DHCP, name service, PEX, cloud-config-service, docker registry services altogether.
#### Pros
1. End2End: basically all admin need to do is to config the cluster.yaml and power on the cluster.
2. Offline cluster configuration: Sextant has 2 phases during working with it, config time and deploy time. when admin is configuring, it requires admin's machine has internet connectivity, which will download some images, etc. But in deploy time, it's completely OK to go offline since all dependencies are ready during config time.
3. docker registry integrated.
4. GPU machine took care of.
### Cons
1. k8s API server is not deployed with high availability in considering by default.
2. No grouping support.
3. No API interface, a one-off service.
### Tectonic installer
First of all, Tectonic is not free, it requires coreos.com account as a step of installation, and free user can only create less than 10 nodes.
Tectonic is a suite of software which wraps around k8s and providing more utility regarding dev ops, ie,
Tectonic installer as it's named, it installs Tectonic to a bare metal cluster which means it's not totally an equivalent of Sextant. At the "booting a cluster" part, it mostly utilizes [Matchbox](https://github.com/coreos/matchbox), which is a general cluster bootstrapper.
Matchbox's Approach is similar to Sexstant.
### Pros
1. supports grouping machines.
2. supports running provisioning service in rtk. (not a big deal though).
3. supports http/gRPC API interface.
4. supports multi-template.
### Cons
1. Not an e2e solution to bring up a cluster, need a lot of extra work and other software.
2. [Not fully supporting](https://github.com/coreos/matchbox/issues/550) centOS deployment yet.
## Conclusion
Sextant is a better solution overall for paddle cloud deploying to a bare metal cluster. It would be great if Sextant can also 1) deploy k8s api server with high availability by default; 2) not designed as a one-off service.
## Appendix: General procedure to bring up a cluster
It's physically impossible for a cluster admin to manually install OS and applications into cluster nodes one by one, here is what an admin would do in cloud industry:
1. setup a bootstrap machine with static IP in the cluster, which has following services:
* DHCP: assigns ip address for rest of the nodes.
* name service: to map node name to a IP
* PXE related services: the booting related info will be delivered to newly booted machines as their IP is assigned via DHCP service, PXE service will provide further booting and installing info and image with TFTP and http protocol.
* cluster config service: this is for providing cluster node with OS config via http
* optional docker registry: a built-in docker registry makes the whole cluster independent from connecting internet, and speeds up software distribution.
2. New node powers on, it will
* broadcast the request for an IP address
* DHCP server assigns the IP address, and deliver the PXE booting related info to the node.
* cluster node will request config files with booting info delivered with DHCP via the TFTP service, and in most of the cases, the config file will point to a http service for the booting image.
* Since PXE is configured with initrd, it will utilize the cloud config service and do further installations like coreOS or K8s installations.
* then restart the node.
For further understanding, following 2 links from Matchbox are some good readings:
* [Machine lifecycle](https://github.com/coreos/matchbox/blob/master/Documentation/machine-lifecycle.md)
* [PXE booting](https://github.com/coreos/matchbox/blob/master/Documentation/network-booting.md)
......@@ -73,14 +73,6 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
}
return val;
}
case framework::AttrType::INT_PAIRS: {
std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
val[i].first = attr_desc.int_pairs(i).first();
val[i].second = attr_desc.int_pairs(i).second();
}
return val;
}
case framework::AttrType::BLOCK: {
return GetProgramDesc().mutable_blocks(attr_desc.block_idx());
}
......
......@@ -29,8 +29,7 @@ namespace framework {
// The order should be as same as framework.proto
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>,
std::vector<std::pair<int, int>>, bool,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>
Attribute;
......
......@@ -22,17 +22,11 @@ enum AttrType {
INTS = 3;
FLOATS = 4;
STRINGS = 5;
INT_PAIRS = 6;
BOOLEAN = 7;
BOOLEANS = 8;
BLOCK = 9;
BOOLEAN = 6;
BOOLEANS = 7;
BLOCK = 8;
}
message IntPair {
required int32 first = 1;
required int32 second = 2;
};
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
......@@ -46,7 +40,6 @@ message OpDesc {
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
repeated IntPair int_pairs = 9;
optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
......
......@@ -100,6 +100,7 @@ public:
if (cnt_ == act.value->getElementCnt()) {
return;
}
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
cnt_ = act.value->getElementCnt();
stream_.reset(new MKLDNNStream());
auto eng = CPUEngine::Instance().getEngine();
......@@ -110,7 +111,6 @@ public:
float alpha = getAlpha();
float beta = getBeta();
/// forward
pipelineFwd_.clear();
val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
if (val_ == nullptr) {
......@@ -152,6 +152,7 @@ public:
if (!needResetBwd_) {
return;
}
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
needResetBwd_ = false;
mkldnn::algorithm algo = getAlgo(this->getName());
float alpha = getBwdAlpha();
......
......@@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap,
// create biases
if (biasParameter_.get() != NULL) {
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
......@@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue(
// create buffer and reorder if input value do not match
cpuInVal_ = nullptr;
cvtInVal_ = nullptr;
if (inputIsOnlyMKLDNN()) {
MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK(dnnIn) << "Input should be MKLDNNMatrix";
if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) {
CHECK_EQ(dnnIn->getFormat(), format::nc);
MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr);
if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
in = dnnIn;
return;
}
if (dnnIn) {
if (dnnIn->getFormat() == format::nc) {
CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format";
// create a new one with nchw format and same data
memory::dims inDims = memory::dims{bs_, ic_, 1, 1};
dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc());
}
in = dnnIn;
if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
in = dnnIn;
return;
}
cpuInVal_ = dnnIn;
in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in);
CHECK(cvtInVal_) << "should not be emptry";
} else {
const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE);
memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_};
cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_);
cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) {
// create new mkldnn matrix
in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
......@@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData(
} else {
wgtValBwdData_ = wgtVal_;
}
VLOG(MKLDNN_FMTS) << "weight value format for backward data"
VLOG(MKLDNN_FMTS) << "weight value format for backward data: "
<< wgtValBwdData_->getFormat();
}
......
......@@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
// create biases
if (biasParameter_.get() != NULL) {
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
......@@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) {
void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias) {
format wgtFmt = format::oihw;
if (inVal_->getFormat() == format::nChw8c) {
wgtFmt = format::oIhw8i;
} else if (inVal_->getFormat() == format::nChw16c) {
wgtFmt = format::oIhw16i;
}
wgt = MKLDNNMatrix::create(
weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_);
weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_);
wgt->downSpatial();
VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat();
bias = (biases_ && biases_->getW())
? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_)
......
......@@ -115,6 +115,7 @@ public:
copySeqInfoToOutputs();
size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt();
if (inputElemenCnt_ != elemenCnt) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
// reset when input total sizes changed, not only the batchsize
inputElemenCnt_ = elemenCnt;
reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
......@@ -142,6 +143,7 @@ public:
void backward(const UpdateCallback& callback) override {
if (needResetBwd_) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
needResetBwd_ = false;
}
......
......@@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel {
auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
x->dims().size(), offsets.size(),
x->dims().size(), static_cast<int64_t>(offsets.size()),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0;
for (int i = 0; i < offsets.size(); ++i) {
for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
}
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
......@@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < D; ++i) {
for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_unit_op.h"
namespace paddle {
namespace operators {
class LstmUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
"Input(C_prev) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
"Output(C) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
"Output(H) of LSTM should not be null.");
auto *x = ctx.Input<framework::Tensor>("X");
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
"Batch size of inputs and states must be equal");
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
"Dimension of FC should equal to prev state * 4");
int b_size = c_prev->dims()[0]; // batch size
int s_dim = c_prev->dims()[1]; // state dim
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
}
};
template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation.");
AddInput(
"C_prev",
"The cell state tensor of last time-step in the Lstm Unit operator.");
AddOutput("C", "The cell tensor of Lstm Unit operator.");
AddOutput("H", "The hidden state tensor of Lstm Unit operator.");
AddComment(R"DOC(Lstm-Unit Operator
Equation:
i, f, o, j = split(X)
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
H = C * sigm(o)
)DOC");
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
.SetDefault(0.0);
}
};
class LstmUnitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
"Input(H@GRAD) should not be null");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
->Resize(ctx.Input<Tensor>("C_prev")->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
lstm_unit_grad, ops::LstmUnitGradOp);
REGISTER_OP_CPU_KERNEL(lstm_unit,
ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
return Dtype(1) / (Dtype(1) + exp(-x));
}
template <typename Dtype>
__device__ Dtype cuda_tanh(const Dtype x) {
return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x));
}
template <typename T>
__global__ void LSTMUnitKernel(const int nthreads, const int dim,
const T* C_prev, const T* X, T* C, T* H,
const T forget_bias) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int n = index / dim;
const int d = index % dim;
const T* X_offset = X + 4 * dim * n;
const T i = cuda_sigmoid(X_offset[d]);
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
const T g = cuda_tanh(X_offset[3 * dim + d]);
const T c_prev = C_prev[index];
const T c = f * c_prev + i * g;
C[index] = c;
const T tanh_c = cuda_tanh(c);
H[index] = o * tanh_c;
}
}
template <typename T>
__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
const T* C_prev, const T* X, const T* C,
const T* H, const T* C_diff,
const T* H_diff, T* C_prev_diff,
T* X_diff, const T forget_bias) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int n = index / dim;
const int d = index % dim;
const T* X_offset = X + 4 * dim * n;
T* c_prev_diff = C_prev_diff + index;
T* X_diff_offset = X_diff + 4 * dim * n;
T* i_diff = X_diff_offset + d;
T* f_diff = X_diff_offset + 1 * dim + d;
T* o_diff = X_diff_offset + 2 * dim + d;
T* g_diff = X_diff_offset + 3 * dim + d;
const T i = cuda_sigmoid(X_offset[d]);
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
const T g = cuda_tanh(X_offset[3 * dim + d]);
const T c_prev = C_prev[index];
const T c = C[index];
const T tanh_c = cuda_tanh(c);
const T c_term_diff =
C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
*c_prev_diff = c_term_diff * f;
*i_diff = c_term_diff * g * i * (1 - i);
*f_diff = c_term_diff * c_prev * f * (1 - f);
*o_diff = H_diff[index] * tanh_c * o * (1 - o);
*g_diff = c_term_diff * i * (1 - g * g);
}
}
template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* x_tensor = ctx.Input<framework::Tensor>("X");
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
const T* X = x_tensor->data<T>();
const T* C_prev = c_prev_tensor->data<T>();
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
int block = 512;
int n = b_size * D;
int grid = (n + block - 1) / block;
LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
}
};
template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto x_tensor = ctx.Input<Tensor>("X");
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
auto c_tensor = ctx.Input<Tensor>("C");
auto h_tensor = ctx.Input<Tensor>("H");
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
auto c_prev_diff_tensor =
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
auto* X = x_tensor->data<T>();
auto* C_prev = c_prev_tensor->data<T>();
auto* C = c_tensor->data<T>();
auto* H = h_tensor->data<T>();
auto* H_diff = hdiff_tensor->data<T>();
auto* C_diff = cdiff_tensor->data<T>();
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int block = 512;
int n = N * D;
int grid = (n + block - 1) / block;
LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
H_diff, C_prev_diff, X_diff,
forget_bias);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto* x_tensor = ctx.Input<framework::Tensor>("X");
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
const T* X = x_tensor->data<T>();
const T* C_prev = c_prev_tensor->data<T>();
for (int n = 0; n < b_size; ++n) {
for (int d = 0; d < D; ++d) {
const T i = sigmoid(X[d]);
const T f = sigmoid(X[1 * D + d] + forget_bias);
const T o = sigmoid(X[2 * D + d]);
const T g = tanh(X[3 * D + d]);
const T c_prev = C_prev[d];
const T c = f * c_prev + i * g;
C[d] = c;
const T tanh_c = tanh(c);
H[d] = o * tanh_c;
}
C_prev += D;
X += 4 * D;
C += D;
H += D;
}
}
};
template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto x_tensor = ctx.Input<Tensor>("X");
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
auto c_tensor = ctx.Input<Tensor>("C");
auto h_tensor = ctx.Input<Tensor>("H");
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
auto c_prev_diff_tensor =
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
auto* X = x_tensor->data<T>();
auto* C_prev = c_prev_tensor->data<T>();
auto* C = c_tensor->data<T>();
auto* H = h_tensor->data<T>();
auto* H_diff = hdiff_tensor->data<T>();
auto* C_diff = cdiff_tensor->data<T>();
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
for (int n = 0; n < N; ++n) {
for (int d = 0; d < D; ++d) {
T* c_prev_diff = C_prev_diff + d;
T* i_diff = X_diff + d;
T* f_diff = X_diff + 1 * D + d;
T* o_diff = X_diff + 2 * D + d;
T* g_diff = X_diff + 3 * D + d;
const T i = sigmoid(X[d]);
const T f = sigmoid(X[1 * D + d] + forget_bias);
const T o = sigmoid(X[2 * D + d]);
const T g = tanh(X[3 * D + d]);
const T c_prev = C_prev[d];
const T c = C[d];
const T tanh_c = tanh(c);
const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
*c_prev_diff = c_term_diff * f;
*i_diff = c_term_diff * g * i * (1 - i);
*f_diff = c_term_diff * c_prev * f * (1 - f);
*o_diff = H_diff[d] * tanh_c * o * (1 - o);
*g_diff = c_term_diff * i * (1 - g * g);
}
C_prev += D;
X += 4 * D;
C += D;
H += D;
C_diff += D;
H_diff += D;
X_diff += 4 * D;
C_prev_diff += D;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -48,6 +48,32 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const float alpha, const float* A,
const int lda, const float* B,
const int ldb, const float beta, float* C,
const int ldc) {
cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const double alpha, const double* A,
const int lda, const double* B,
const int ldb, const double beta,
double* C, const int ldc) {
cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
......
......@@ -63,6 +63,42 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, float>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const float alpha, const float* A,
const int lda, const float* B,
const int ldb, const float beta, float* C,
const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
const bool transA, const bool transB,
const int M, const int N, const int K,
const double alpha, const double* A,
const int lda, const double* B,
const int ldb, const double beta,
double* C, const int ldc) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void matmul<platform::GPUPlace, float>(
const platform::DeviceContext& context, const framework::Tensor& matrix_a,
......
......@@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
const T alpha, const T* A, const T* B, const T beta, T* C);
// gemm wrapper with stride args for matrix uncontinuous in memory
template <typename Place, typename T>
void gemm(const platform::DeviceContext& context, const bool transA,
const bool transB, const int M, const int N, const int K,
const T alpha, const T* A, const int lda, const T* B, const int ldb,
const T beta, T* C, const int ldc);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const platform::DeviceContext& context,
......
......@@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) {
EXPECT_EQ(out_ptr[8], 29);
delete gpu_place;
}
TEST(math_function, gemm_notrans_cublas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor input3_gpu;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
TEST(math_function, gemm_trans_cublas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor input3_gpu;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
#endif
TEST(math_function, gemm_notrans_cblas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place);
float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
TEST(math_function, gemm_trans_clbas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
paddle::framework::Tensor input3;
int m = 2;
int n = 3;
int k = 3;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr1[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr1, 6 * sizeof(float));
float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place);
float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11};
memcpy(input2_ptr, arr2, 12 * sizeof(float));
float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place);
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
EXPECT_EQ(input3_ptr[3], 32);
EXPECT_EQ(input3_ptr[4], 4);
EXPECT_EQ(input3_ptr[5], 73);
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/multiplex_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class MultiplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) shouldn't be null.");
auto ins = ctx.MultiInput<Tensor>("X");
auto *out = ctx.Output<LoDTensor>("Out");
auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 2,
"multiplex operator should have more than 2 inputs.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
"The first input must be a index vector.");
auto in_dim = ins[1]->dims();
for (size_t i = 2; i < num_ins; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(
in_dim == dim,
"All the input tensors except the first one must have the same size");
}
out->Resize(in_dim);
}
};
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiplexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
ins[0]: the index tensor.
ins[1:N]: the candidate output tensors.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (index[i] + 1)-th tensor.
For i-th row of the output tensor:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = x{0}[i] + 1`.
)DOC");
}
};
class MultiplexGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<Tensor>("X");
// don't compute gradient for index (ins[0])
for (size_t i = 1; i < ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->Resize(ins[i]->dims());
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
ops::MultiplexGradOp);
REGISTER_OP_CPU_KERNEL(
multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
multiplex_grad,
ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/multiplex_op.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
// copy index to cpu
framework::Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
PADDLE_ENFORCE_LT(k, ins.size(),
"index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place,
ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
}
}
};
template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (size_t i = 1; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
}
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
// copy index to cpu
framework::Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T), stream);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
multiplex, ops::MultiplexGPUKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
multiplex_grad,
ops::MultiplexGradGPUKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class MultiplexCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto* index = ins[0]->data<T>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
"index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place,
ins[k]->data<T>() + i * cols, cols * sizeof(T));
}
}
};
template <typename Place, typename T>
class MultiplexGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (size_t i = 1; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
}
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
auto* index = ins[0]->data<T>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T));
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_avg_pool_op.h"
#include "paddle/operators/sequence_pool_op.h"
namespace paddle {
namespace operators {
class SequenceAvgPoolOp : public framework::OperatorWithKernel {
class SequencePoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of SequenceAvgPoolOp should not be null.");
"Output(Out) of SequencePoolOp should not be null.");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims();
......@@ -42,21 +42,45 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel {
}
};
class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker {
class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceAvgPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequencePoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SequenceAvgPoolOp.");
AddOutput("Out", "The output of SequenceAvgPoolOp.");
AddInput("X",
"A float LoDTensor, the variable-length input of SequencePoolOp");
AddOutput(
"Out",
"A float LoDTensor, the variable-length output of SequencePoolOp.");
AddAttr<int>(
"strategy",
"(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
.SetDefault(AVERAGE)
.InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST});
AddComment(R"DOC(
SequenceAvgPoolOp averages features of all time-steps of each instance.
More detailed comments will be added later.
SequencePoolOp pools features of all time-steps of each instance.
For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps:
Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7].
Besides, for the sake of simplicity, we assume M=1 and N=1,
and the value of X = [[1, 3], [2, 4, 6], [5, 1]].
Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr.
And for different strategy, the value of Out is as follows:
- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
- SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
- MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
- LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
- FIRST: [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
)DOC");
}
};
class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
class SequencePoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -84,12 +108,10 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp,
ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad,
ops::SequenceAvgPoolGradOp);
REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
sequence_pool_grad, ops::SequencePoolGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_avg_pool,
ops::SequenceAvgPoolKernel<paddle::platform::CPUPlace, float>);
sequence_pool, ops::SequencePoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_avg_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::CPUPlace, float>);
sequence_pool_grad,
ops::SequencePoolGradKernel<paddle::platform::CPUPlace, float>);
......@@ -14,12 +14,11 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_avg_pool_op.h"
#include "paddle/operators/sequence_pool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_avg_pool,
ops::SequenceAvgPoolKernel<paddle::platform::GPUPlace, float>);
sequence_pool, ops::SequencePoolKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_avg_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::GPUPlace, float>);
sequence_pool_grad,
ops::SequencePoolGradKernel<paddle::platform::GPUPlace, float>);
......@@ -28,54 +28,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
enum SeqPoolType {
AVERAGE = 0,
SUM = 1,
SQRT = 2, // square_root_n
MAX = 3,
LAST = 4,
FIRST = 5
};
template <typename Place, typename T>
class SequenceAvgPoolKernel : public framework::OpKernel {
class SequencePoolKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int strategy = context.Attr<int>("strategy");
auto dims = in->dims();
auto lod = in->lod();
auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0];
out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t =
in->Slice<T>(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t);
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
switch (strategy) {
case AVERAGE:
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
break;
case SUM:
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
break;
default:
PADDLE_THROW("unsupported pooling strategy");
}
}
}
};
template <typename Place, typename T>
class SequenceAvgPoolGradKernel : public framework::OpKernel {
class SequencePoolGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
int strategy = context.Attr<int>("strategy");
auto dims = in->dims();
auto lod = in->lod();
auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1]));
auto out_g_t = out_g->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, 1);
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
switch (strategy) {
case AVERAGE:
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
break;
case SUM:
in_g_e.device(place) = (out_g_e).broadcast(bcast);
break;
default:
PADDLE_THROW("unsupported pooling strategy");
}
}
}
};
......
......@@ -34,13 +34,14 @@ class DeviceContext {
template <typename DeviceType>
DeviceType* get_eigen_device() const;
virtual void Wait() const {}
};
class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const;
......@@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext {
virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */
void Wait() const;
void Wait() const override;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......
......@@ -36,7 +36,7 @@ int GetCurrentDeviceId();
//! Set the GPU device id for next execution.
void SetDeviceId(int device_id);
//Get the memory usage of current GPU device.
//! Get the memory usage of current GPU device.
void GpuMemoryUsage(size_t &available, size_t &total);
//! Get the maximum allocation size of current GPU device.
......
......@@ -229,7 +229,13 @@ All parameter, weight, gradient are variables in Paddle.
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("run",
[](OperatorBase &self,
const Scope &scope,
const platform::DeviceContext &dev_ctx) {
self.Run(scope, dev_ctx);
dev_ctx.Wait();
})
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......
......@@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets
--config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
################ test_CompareMKLDNNandCPU ######################
if(WITH_MKLDNN)
add_unittest_without_exec(test_CompareMKLDNNandCPU
test_CompareTwoNets.cpp)
add_test(NAME test_CompareMKLDNNandCPU
COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/
${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU
--config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True
--config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False
--use_gpu=False
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
endif()
############### test_CompareTwoOpts ###################
add_unittest_without_exec(test_CompareTwoOpts
test_CompareTwoOpts.cpp)
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
################################### Data Configuration ###################################
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
################################### Algorithm Configuration ###################################
settings(batch_size = 1000,
learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
################################### Network Configuration ###################################
data = data_layer(name ="input", size=784)
tmp = img_conv_layer(input=data,
num_channels=1,
filter_size=3,
num_filters=32,
padding=1,
shared_biases=True,
act=ReluActivation())
tmp = img_pool_layer(input=tmp,
pool_size=3,
stride=2,
padding=1,
pool_type=AvgPooling())
tmp = img_conv_layer(input=tmp,
filter_size=3,
num_filters=64,
padding=1,
shared_biases=True,
act=ReluActivation())
tmp = img_pool_layer(input=tmp,
pool_size=3,
stride=2,
padding=1,
pool_type=MaxPooling())
tmp = fc_layer(input=tmp, size=64,
bias_attr=True,
act=ReluActivation())
output = fc_layer(input=tmp, size=10,
bias_attr=True,
act=SoftmaxActivation())
lbl = data_layer(name ="label", size=10)
cost = classification_cost(input=output, label=lbl)
outputs(cost)
......@@ -26,12 +26,15 @@ DECLARE_int32(gpu_id);
DECLARE_bool(local);
DECLARE_bool(use_gpu);
DECLARE_bool(use_mkldnn);
DECLARE_string(config);
DECLARE_string(nics);
DEFINE_string(config_file_a, "", "config of one network to compare");
DEFINE_string(config_file_b, "", "config of another network to compare");
DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run config_file_a");
DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run config_file_b");
DEFINE_bool(need_high_accuracy,
false,
"whether need to run in double accuracy");
......@@ -128,6 +131,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) {
matA.getWidth());
}
if (FLAGS_use_mkldnn_a || FLAGS_use_mkldnn_b) {
// some format of mkldnn parameter is different with cpu
// test_MKLDNN will check the parameters
return;
}
vector<ParameterPtr>& parametersA = comDataA.parameters;
vector<ParameterPtr>& parametersB = comDataB.parameters;
......@@ -167,10 +176,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) {
TEST(Trainer, create) {
ComData dataA;
FLAGS_use_mkldnn = FLAGS_use_mkldnn_a;
calcGradient(dataA, FLAGS_config_file_a);
LOG(INFO) << "\n\nforwardBackward of Network A is finished\n\n";
ComData dataB;
FLAGS_use_mkldnn = FLAGS_use_mkldnn_b;
calcGradient(dataB, FLAGS_config_file_b);
LOG(INFO) << "\n\nforwardBackward of the Network B is finished\n\n";
......
......@@ -921,7 +921,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
data = data_layer(name="input", size=1000)
:param name: The name of this layer. It is optional.
:param name: The name of this layer.
:type name: basestring
:param size: Size of this data layer.
:type size: int
......@@ -3668,6 +3668,7 @@ def gru_step_naive_layer(input,
:param param_attr:
:param layer_attr:
:return:
:rtype: LayerOutput
"""
if input.size % 3 != 0:
raise ValueError("GruStep input size must be divided by 3")
......
import unittest
import numpy as np
from op_test import OpTest
def sigmoid_np(x):
return 1. / (1. + np.exp(-x))
def tanh_np(x):
return 2 * sigmoid_np(2. * x) - 1.
class LstmUnitTest(OpTest):
def setUp(self):
self.op_type = "lstm_unit"
x_np = np.random.normal(size=(5, 16)).astype("float32")
c_np = np.random.normal(size=(5, 4)).astype("float32")
i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1)
forget_bias_np = 0.
self.attrs = {'forget_bias': 0.}
new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np(
i_np) * tanh_np(j_np)
new_h = tanh_np(new_c) * sigmoid_np(o_np)
self.inputs = {'X': x_np, 'C_prev': c_np}
self.outputs = {'C': new_c, 'H': new_h}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestMultiplexOp(OpTest):
def setUp(self):
self.op_type = "multiplex"
rows = 3
index = np.array([3, 1, 0])
ins1 = np.random.random((rows, 10)).astype("float32")
ins2 = np.random.random((rows, 10)).astype("float32")
ins3 = np.random.random((rows, 10)).astype("float32")
ins4 = np.random.random((rows, 10)).astype("float32")
self.inputs = {
'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
('x4', ins4)]
}
# multiplex output
output = np.zeros_like(ins1)
for i in range(0, rows):
k = index[i] + 1
output[i] = self.inputs['X'][k][1][i]
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out')
def test_check_grad_ignore_x1(self):
self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'))
def test_check_grad_ignore_x1_x2(self):
self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2']))
def test_check_grad_ignore_x3(self):
self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'))
if __name__ == '__main__':
unittest.main()
......@@ -7,6 +7,14 @@ class PReluTest(OpTest):
def setUp(self):
self.op_type = "prelu"
x_np = np.random.normal(size=(10, 10)).astype("float32")
for pos, val in np.ndenumerate(x_np):
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
while abs(val) < 1e-3:
x_np[pos] = np.random.normal()
val = x_np[pos]
x_np_sign = np.sign(x_np)
x_np = x_np_sign * np.maximum(x_np, .005)
alpha_np = np.array([.1])
......
......@@ -3,20 +3,37 @@ import numpy as np
from op_test import OpTest
class TestSeqAvgPool1D(OpTest):
def setUp(self):
self.op_type = 'sequence_avg_pool'
class SeqPoolType(OpTest):
AVERAGE = 0
SUM = 1
SQRT = 2
MAX = 3
LAST = 4
FIRST = 5
class TestSeqAvgPool(OpTest):
def set_data(self):
self.op_type = 'sequence_pool'
# one level, batch size is 4
x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
lod = [[0, 4, 5, 8, 11]]
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 23)).astype('float32')
self.outputs = {'Out': out}
def compute(self):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.mean(axis=0)
self.inputs = {'X': (x, lod)}
self.outputs = {'Out': out}
def setUp(self):
self.set_data()
self.compute()
def test_check_output(self):
self.check_output()
......@@ -25,26 +42,44 @@ class TestSeqAvgPool1D(OpTest):
self.check_grad(["X"], "Out")
class TestSeqAvgPool2D(OpTest):
def setUp(self):
self.op_type = 'sequence_avg_pool'
class TestSeqAvgPool2D(TestSeqAvgPool):
def set_data(self):
self.op_type = 'sequence_pool'
# one level, batch size is 4
x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
lod = [[0, 4, 5, 8, 13]]
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 3, 17)).astype('float32')
self.outputs = {'Out': out}
def compute(self):
self.attrs = {'strategy': SeqPoolType.AVERAGE}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
self.inputs = {'X': (x, lod)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestSeqSumPool(TestSeqAvgPool):
def compute(self):
self.attrs = {'strategy': SeqPoolType.SUM}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSeqSumPool2D(TestSeqAvgPool2D):
def compute(self):
self.attrs = {'strategy': SeqPoolType.SUM}
x, lod = self.inputs['X']
out = self.outputs['Out']
for i in range(4):
sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册