提交 3d490d03 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into pixel_softmax_layer

...@@ -434,9 +434,9 @@ lambda_cost ...@@ -434,9 +434,9 @@ lambda_cost
.. autoclass:: paddle.v2.layer.lambda_cost .. autoclass:: paddle.v2.layer.lambda_cost
:noindex: :noindex:
mse_cost square_error_cost
-------- --------
.. autoclass:: paddle.v2.layer.mse_cost .. autoclass:: paddle.v2.layer.square_error_cost
:noindex: :noindex:
rank_cost rank_cost
......
# Design Doc: Computations as Graphs
A primary goal of the refactorization of PaddlePaddle is a more flexible representation of deep learning computation, in particular, a graph of operators and variables, instead of sequences of layers as before.
This document explains that the construction of a graph as three steps:
- construct the forward part
- construct the backward part
- construct the optimization part
Let us take the problem of image classification as a simple example. The application program that trains the model looks like:
```python
x = layer.data("images")
l = layer.data("label")
y = layer.fc(x)
cost = layer.mse(y, l)
optimize(cost)
train(cost, reader=mnist.train())
```
### Forward Part
The first four lines of above program build the forward part of the graph.
![](images/graph_construction_example_forward_only.png)
In particular, the first line `x = layer.data("images")` creates variable x and a Feed operator that copies a column from the minibatch to x. `y = layer.fc(x)` creates not only the FC operator and output variable y, but also two parameters, W and b.
In this example, all operators are created as `OpDesc` protobuf messages, and all variables are `VarDesc`. These protobuf messages are saved in a `BlockDesc` protobuf message.
### Backward Part
The fifth line `optimize(cost)` calls two functions, `ConstructBackwardGraph` and `ConstructOptimizationGraph`.
`ConstructBackwardGraph` traverses the forward graph in the `BlockDesc` protobuf message and builds the backward part.
![](images/graph_construction_example_forward_backward.png)
According to the chain rule of gradient computation, `ConstructBackwardGraph` would
1. create a gradient operator G for each operator F,
1. make all inputs, outputs, and outputs' gradient of F as inputs of G,
1. create gradients for all inputs of F, except for those who don't have gradients, like x and l, and
1. make all these gradients as outputs of G.
### Optimization Part
For each parameter, like W and b created by `layer.fc`, marked as double circles in above graphs, `ConstructOptimizationGraph` creates an optimization operator to apply its gradient. Here results in the complete graph:
![](images/graph_construction_example_all.png)
cat ./graph_construction_example.dot | \
sed 's/color=red/color=red, style=invis/g' | \
sed 's/color=green/color=green, style=invis/g' | \
dot -Tpng > graph_construction_example_forward_only.png
cat ./graph_construction_example.dot | \
sed 's/color=green/color=green, style=invis/g' | \
dot -Tpng > graph_construction_example_forward_backward.png
cat ./graph_construction_example.dot | \
dot -Tpng > graph_construction_example_all.png
digraph ImageClassificationGraph {
///////// The forward part /////////
FeedX [label="Feed", color=blue, shape=box];
FeedY [label="Feed", color=blue, shape=box];
FC [label="FC", color=blue, shape=box];
MSE [label="MSE", color=blue, shape=box];
x [label="x", color=blue, shape=oval];
l [label="l", color=blue, shape=oval];
y [label="y", color=blue, shape=oval];
W [label="W", color=blue, shape=doublecircle];
b [label="b", color=blue, shape=doublecircle];
cost [label="cost", color=blue, shape=oval];
FeedX -> x -> FC -> y -> MSE -> cost [color=blue];
FeedY -> l [color=blue];
W -> FC [color=blue];
b -> FC [color=blue];
l -> MSE [color=blue];
////////// The backward part /////////
MSE_Grad [label="MSE_grad", color=red, shape=box];
FC_Grad [label="FC_grad", color=red, shape=box];
d_cost [label="d cost", color=red, shape=oval];
d_y [label="d y", color=red, shape=oval];
d_b [label="d b", color=red, shape=oval];
d_W [label="d W", color=red, shape=oval];
cost -> MSE_Grad [color=red];
d_cost -> MSE_Grad [color=red];
x -> MSE_Grad [color=red];
l -> MSE_Grad [color=red];
y -> MSE_Grad -> d_y [color=red];
x -> FC_Grad [color=red];
y -> FC_Grad [color=red];
d_y -> FC_Grad [color=red];
W -> FC_Grad -> d_W [color=red];
b -> FC_Grad -> d_b [color=red];
////////// The optimizaiton part //////////
OPT_W [label="SGD", color=green, shape=box];
OPT_b [label="SGD", color=green, shape=box];
W -> OPT_W [color=green];
b -> OPT_b [color=green];
d_W -> OPT_W -> W [color=green];
d_b -> OPT_b -> b [color=green];
////////// Groupings //////////
subgraph clusterMSE {
style=invis;
MSE;
MSE_Grad;
}
subgraph clusterFC {
style=invis;
FC;
FC_Grad;
}
}
...@@ -55,7 +55,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍 ...@@ -55,7 +55,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍
# 线性计算网络层: ȳ = wx + b # 线性计算网络层: ȳ = wx + b
ȳ = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b')) ȳ = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b'))
# 计算误差函数,即 ȳ 和真实 y 之间的距离 # 计算误差函数,即 ȳ 和真实 y 之间的距离
cost = mse_cost(input= ȳ, label=y) cost = square_error_cost(input= ȳ, label=y)
outputs(cost) outputs(cost)
...@@ -69,7 +69,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍 ...@@ -69,7 +69,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍
- **数据层**:数据层 `data_layer` 是神经网络的入口,它读入数据并将它们传输到接下来的网络层。这里数据层有两个,分别对应于变量 `x` 和 `y`。 - **数据层**:数据层 `data_layer` 是神经网络的入口,它读入数据并将它们传输到接下来的网络层。这里数据层有两个,分别对应于变量 `x` 和 `y`。
- **全连接层**:全连接层 `fc_layer` 是基础的计算单元,这里利用它建模变量之间的线性关系。计算单元是神经网络的核心,PaddlePaddle支持大量的计算单元和任意深度的网络连接,从而可以拟合任意的函数来学习复杂的数据关系。 - **全连接层**:全连接层 `fc_layer` 是基础的计算单元,这里利用它建模变量之间的线性关系。计算单元是神经网络的核心,PaddlePaddle支持大量的计算单元和任意深度的网络连接,从而可以拟合任意的函数来学习复杂的数据关系。
- **回归误差代价层**:回归误差代价层 `mse_cost` 是众多误差代价函数层的一种,它们在训练过程作为网络的出口,用来计算模型的误差,是模型参数优化的目标函数。 - **回归误差代价层**:回归误差代价层 `square_error_cost` 是众多误差代价函数层的一种,它们在训练过程作为网络的出口,用来计算模型的误差,是模型参数优化的目标函数。
定义了网络结构并保存为 `trainer_config.py` 之后,运行以下训练命令: 定义了网络结构并保存为 `trainer_config.py` 之后,运行以下训练命令:
......
...@@ -49,7 +49,7 @@ To recover this relationship between ``X`` and ``Y``, we use a neural network wi ...@@ -49,7 +49,7 @@ To recover this relationship between ``X`` and ``Y``, we use a neural network wi
x = data_layer(name='x', size=1) x = data_layer(name='x', size=1)
y = data_layer(name='y', size=1) y = data_layer(name='y', size=1)
y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b')) y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b'))
cost = mse_cost(input=y_predict, label=y) cost = square_error_cost(input=y_predict, label=y)
outputs(cost) outputs(cost)
Some of the most fundamental usages of PaddlePaddle are demonstrated: Some of the most fundamental usages of PaddlePaddle are demonstrated:
......
...@@ -8,7 +8,7 @@ paddle.init(use_gpu=False) ...@@ -8,7 +8,7 @@ paddle.init(use_gpu=False)
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(2)) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(2))
y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y) cost = paddle.layer.square_error_cost(input=y_predict, label=y)
# create parameters # create parameters
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
......
...@@ -81,9 +81,9 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 ...@@ -81,9 +81,9 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和
.. code-block:: bash .. code-block:: bash
y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
cost = paddle.layer.mse_cost(input=y_predict, label=y) cost = paddle.layer.square_error_cost(input=y_predict, label=y)
其中,x与y为之前描述的输入层;而y_predict是接收x作为输入,接上一个全连接层;cost接收y_predict与y作为输入,接上方误差层。 其中,x与y为之前描述的输入层;而y_predict是接收x作为输入,接上一个全连接层;cost接收y_predict与y作为输入,接上方误差层。
最后一层cost中记录了神经网络的所有拓扑结构,通过组合不同的layer,我们即可完成神经网络的搭建。 最后一层cost中记录了神经网络的所有拓扑结构,通过组合不同的layer,我们即可完成神经网络的搭建。
...@@ -147,4 +147,4 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 ...@@ -147,4 +147,4 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和
.. literalinclude:: src/train.py .. literalinclude:: src/train.py
:linenos: :linenos:
有关线性回归的实际应用,可以参考PaddlePaddle book的 `第一章节 <http://book.paddlepaddle.org/index.html>`_。 有关线性回归的实际应用,可以参考PaddlePaddle book的 `第一章节 <http://book.paddlepaddle.org/index.html>`_。
\ No newline at end of file
...@@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel { ...@@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel {
```c++ ```c++
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad, REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>); ops::MulGradKernel<paddle::platform::CPUPlace, float>);
``` ```
- `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker``ops::MulOpMaker`并且注册`ops::MulOpGrad`为其反向Op。 - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker``ops::MulOpMaker`注册`ops::MulOpGrad`,类型名为`mul_grad`
- `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。 - `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。
- `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace``float`类型,同理,注册`ops::MulKernel`类。 - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace``float`类型,同理,注册`ops::MulKernel`类。
...@@ -227,6 +227,12 @@ make mul_op ...@@ -227,6 +227,12 @@ make mul_op
USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(gather);
``` ```
如果OP不带Kernel,则使用`USE_NO_KENREL_OP`:
```
USE_NO_KENREL_OP(recurrent);
```
使用`USE_OP`告知编译器需要链接该Op的目标文件,具体解释参考[代码注释](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h#L81)。 使用`USE_OP`告知编译器需要链接该Op的目标文件,具体解释参考[代码注释](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h#L81)。
......
...@@ -213,7 +213,7 @@ I1116 09:10:17.123440 50 Util.cpp:130] Calling runInitFunctions ...@@ -213,7 +213,7 @@ I1116 09:10:17.123440 50 Util.cpp:130] Calling runInitFunctions
I1116 09:10:17.123764 50 Util.cpp:143] Call runInitFunctions done. I1116 09:10:17.123764 50 Util.cpp:143] Call runInitFunctions done.
[WARNING 2016-11-16 09:10:17,227 default_decorators.py:40] please use keyword arguments in paddle config. [WARNING 2016-11-16 09:10:17,227 default_decorators.py:40] please use keyword arguments in paddle config.
[INFO 2016-11-16 09:10:17,239 networks.py:1282] The input order is [movie_id, title, genres, user_id, gender, age, occupation, rating] [INFO 2016-11-16 09:10:17,239 networks.py:1282] The input order is [movie_id, title, genres, user_id, gender, age, occupation, rating]
[INFO 2016-11-16 09:10:17,239 networks.py:1289] The output order is [__mse_cost_0__] [INFO 2016-11-16 09:10:17,239 networks.py:1289] The output order is [__square_error_cost_0__]
I1116 09:10:17.392917 50 Trainer.cpp:170] trainer mode: Normal I1116 09:10:17.392917 50 Trainer.cpp:170] trainer mode: Normal
I1116 09:10:17.613910 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process I1116 09:10:17.613910 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process
I1116 09:10:17.680917 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process I1116 09:10:17.680917 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process
......
...@@ -182,7 +182,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -182,7 +182,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}); });
// process recurrent gradient op as a special operator. // process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent_op") { if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or // NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// this will result in infinite loop. // this will result in infinite loop.
const auto& rnnop = const auto& rnnop =
......
...@@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato ...@@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato
For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro: For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
```cpp ```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad); REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
``` ```
`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. `mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
......
...@@ -148,14 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -148,14 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP); REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP); f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP); REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP);
TEST(Backward, simple_op_grad) { TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp( auto fwd = f::OpRegistry::CreateOp(
......
...@@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) { ...@@ -54,8 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
} }
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP); REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP); REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
TEST(GradOpBuilder, MutiInOut) { TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
......
...@@ -33,7 +33,8 @@ namespace framework { ...@@ -33,7 +33,8 @@ namespace framework {
class OpRegistry { class OpRegistry {
public: public:
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type); "'%s' is registered more than once.", op_type);
OpInfo op_info; OpInfo op_info;
...@@ -42,9 +43,9 @@ class OpRegistry { ...@@ -42,9 +43,9 @@ class OpRegistry {
const VariableNameMap& outputs, const AttributeMap& attrs) { const VariableNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs); return new OpType(type, inputs, outputs, attrs);
}; };
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) != if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) { std::type_index(typeid(NOPMaker))) {
op_info.grad_op_type_ = op_type + "_grad";
op_info.proto_ = new OpProto; op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker; op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
...@@ -54,14 +55,15 @@ class OpRegistry { ...@@ -54,14 +55,15 @@ class OpRegistry {
op_info.proto_->IsInitialized(), op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized", "Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString()); op_type, op_info.proto_->InitializationErrorString());
// register gradient op
RegisterOp<GradOpType, NOPMaker, NOP>(op_info.grad_op_type_);
} else { } else {
op_info.grad_op_type_ = "";
op_info.proto_ = nullptr; op_info.proto_ = nullptr;
op_info.checker_ = nullptr; op_info.checker_ = nullptr;
} }
OpInfoMap::Instance().Insert(op_type, op_info); OpInfoMap::Instance().Insert(op_type, op_info);
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
}
} }
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type, static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
...@@ -90,8 +92,10 @@ class Registrar { ...@@ -90,8 +92,10 @@ class Registrar {
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar { class OpRegistrar : public Registrar {
public: public:
explicit OpRegistrar(const char* op_type) { explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type); OpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
grad_op_type);
} }
}; };
...@@ -117,7 +121,8 @@ class OpKernelRegistrar : public Registrar { ...@@ -117,7 +121,8 @@ class OpKernelRegistrar : public Registrar {
/** /**
* Macro to register Operator. * Macro to register Operator.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \ class _OpClass_##op_type##_ : public op_class { \
...@@ -132,14 +137,14 @@ class OpKernelRegistrar : public Registrar { ...@@ -132,14 +137,14 @@ class OpKernelRegistrar : public Registrar {
}; \ }; \
static ::paddle::framework::OpRegistrar< \ static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type); \ __op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \ __op_registrar_##op_type##__.Touch(); \
return 0; \ return 0; \
} }
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP) REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
...@@ -194,6 +199,8 @@ class OpKernelRegistrar : public Registrar { ...@@ -194,6 +199,8 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU) USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif #endif
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
#define USE_CPU_ONLY_OP(op_type) \ #define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); USE_OP_DEVICE_KERNEL(op_type, CPU);
......
...@@ -14,18 +14,20 @@ limitations under the License. */ ...@@ -14,18 +14,20 @@ limitations under the License. */
#include "Evaluator.h" #include "Evaluator.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/StringUtil.h"
namespace paddle { namespace paddle {
/** /**
* calculate sequence-to-sequence edit distance * calculate sequence-to-sequence edit distance
*/ */
class CTCErrorEvaluator : public NotGetableEvaluator { class CTCErrorEvaluator : public Evaluator {
private: private:
MatrixPtr outActivations_; MatrixPtr outActivations_;
int numTimes_, numClasses_, numSequences_, blank_; int numTimes_, numClasses_, numSequences_, blank_;
real deletions_, insertions_, substitutions_; real deletions_, insertions_, substitutions_;
int seqClassficationError_; int seqClassficationError_;
mutable std::unordered_map<std::string, real> evalResults_;
std::vector<int> path2String(const std::vector<int>& path) { std::vector<int> path2String(const std::vector<int>& path) {
std::vector<int> str; std::vector<int> str;
...@@ -183,6 +185,18 @@ private: ...@@ -183,6 +185,18 @@ private:
return stringAlignment(gtStr, recogStr); return stringAlignment(gtStr, recogStr);
} }
void storeLocalValues() const {
evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0;
evalResults_["deletion_error"] =
numSequences_ ? deletions_ / numSequences_ : 0;
evalResults_["insertion_error"] =
numSequences_ ? insertions_ / numSequences_ : 0;
evalResults_["substitution_error"] =
numSequences_ ? substitutions_ / numSequences_ : 0;
evalResults_["sequence_error"] =
(real)seqClassficationError_ / numSequences_;
}
public: public:
CTCErrorEvaluator() CTCErrorEvaluator()
: numTimes_(0), : numTimes_(0),
...@@ -245,16 +259,12 @@ public: ...@@ -245,16 +259,12 @@ public:
} }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
os << config_.name() << "=" storeLocalValues();
<< (numSequences_ ? totalScore_ / numSequences_ : 0); os << config_.name() << " error = " << evalResults_["error"];
os << " deletions error" os << " deletions error = " << evalResults_["deletion_error"];
<< "=" << (numSequences_ ? deletions_ / numSequences_ : 0); os << " insertions error = " << evalResults_["insertion_error"];
os << " insertions error" os << " substitution error = " << evalResults_["substitution_error"];
<< "=" << (numSequences_ ? insertions_ / numSequences_ : 0); os << " sequence error = " << evalResults_["sequence_error"];
os << " substitutions error"
<< "=" << (numSequences_ ? substitutions_ / numSequences_ : 0);
os << " sequences error"
<< "=" << (real)seqClassficationError_ / numSequences_;
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
...@@ -272,6 +282,37 @@ public: ...@@ -272,6 +282,37 @@ public:
seqClassficationError_ = (int)buf[4]; seqClassficationError_ = (int)buf[4];
numSequences_ = (int)buf[5]; numSequences_ = (int)buf[5];
} }
void getNames(std::vector<std::string>* names) {
storeLocalValues();
names->reserve(names->size() + evalResults_.size());
for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) {
names->push_back(config_.name() + "." + it->first);
}
}
real getValue(const std::string& name, Error* err) const {
storeLocalValues();
std::vector<std::string> buffers;
paddle::str::split(name, '.', &buffers);
auto it = evalResults_.find(buffers[buffers.size() - 1]);
if (it == evalResults_.end()) {
*err = Error("Evaluator does not have the key %s", name.c_str());
return 0.0f;
}
return it->second;
}
std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "ctc_edit_distance";
}
}; };
REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
......
...@@ -268,7 +268,13 @@ public: ...@@ -268,7 +268,13 @@ public:
} }
// get type of evaluator // get type of evaluator
std::string getTypeImpl() const { return "chunk"; } std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "chunk";
}
private: private:
void storeLocalValues() const { void storeLocalValues() const {
......
...@@ -211,6 +211,7 @@ public: ...@@ -211,6 +211,7 @@ public:
*err = Error("Not implemented"); *err = Error("Not implemented");
return .0f; return .0f;
} }
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const {
*err = Error("Not implemented"); *err = Error("Not implemented");
return ""; return "";
...@@ -331,6 +332,7 @@ private: ...@@ -331,6 +332,7 @@ private:
protected: protected:
std::string getTypeImpl() const; std::string getTypeImpl() const;
}; };
/** /**
* @brief precision, recall and f1 score Evaluator * @brief precision, recall and f1 score Evaluator
* \f[ * \f[
...@@ -358,6 +360,12 @@ public: ...@@ -358,6 +360,12 @@ public:
virtual void distributeEval(ParameterClient2* client); virtual void distributeEval(ParameterClient2* client);
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
struct StatsInfo { struct StatsInfo {
/// numbers of true positives /// numbers of true positives
double TP; double TP;
...@@ -428,11 +436,6 @@ private: ...@@ -428,11 +436,6 @@ private:
mutable std::unordered_map<std::string, real> values_; mutable std::unordered_map<std::string, real> values_;
void storeLocalValues() const; void storeLocalValues() const;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
}; };
/* /*
......
...@@ -42,10 +42,10 @@ bool Conv3DLayer::init(const LayerMap &layerMap, ...@@ -42,10 +42,10 @@ bool Conv3DLayer::init(const LayerMap &layerMap,
if (sharedBiases_) { if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ = biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_)); std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else { } else {
biases_ = biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_)); std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
} }
} }
return true; return true;
...@@ -224,20 +224,31 @@ void Conv3DLayer::bpropData(int i) { ...@@ -224,20 +224,31 @@ void Conv3DLayer::bpropData(int i) {
} }
void Conv3DLayer::bpropBiases() { void Conv3DLayer::bpropBiases() {
MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
1,
biases_->getWGrad()->getElementCnt(),
false,
useGpu_);
MatrixPtr outGradMat = getOutputGrad(); MatrixPtr outGradMat = getOutputGrad();
if (this->sharedBiases_) { if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); biases->collectSharedBias(*outGradMat, 1.0f);
} else { } else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f); biases->collectBias(*outGradMat, 1.0f);
} }
} }
void Conv3DLayer::addBias() { void Conv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue(); MatrixPtr outMat = getOutputValue();
MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
1,
biases_->getW()->getElementCnt(),
false,
useGpu_);
if (this->sharedBiases_) { if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f); outMat->addSharedBias(*(bias), 1.0f);
} else { } else {
outMat->addBias(*(biases_->getW()), 1.0f); outMat->addBias(*(bias), 1.0f);
} }
} }
......
...@@ -42,10 +42,10 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, ...@@ -42,10 +42,10 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
if (sharedBiases_) { if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ = biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_)); std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else { } else {
biases_ = biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_)); std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
} }
} }
return true; return true;
...@@ -191,21 +191,31 @@ void DeConv3DLayer::bpropWeights(int i) {} ...@@ -191,21 +191,31 @@ void DeConv3DLayer::bpropWeights(int i) {}
void DeConv3DLayer::bpropData(int i) {} void DeConv3DLayer::bpropData(int i) {}
void DeConv3DLayer::bpropBiases() { void DeConv3DLayer::bpropBiases() {
MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
1,
biases_->getWGrad()->getElementCnt(),
false,
useGpu_);
const MatrixPtr &outGradMat = getOutputGrad(); const MatrixPtr &outGradMat = getOutputGrad();
if (this->sharedBiases_) { if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); biases->collectSharedBias(*outGradMat, 1.0f);
} else { } else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f); biases->collectBias(*outGradMat, 1.0f);
} }
} }
void DeConv3DLayer::addBias() { void DeConv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue(); MatrixPtr outMat = getOutputValue();
MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
1,
biases_->getW()->getElementCnt(),
false,
useGpu_);
if (this->sharedBiases_) { if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f); outMat->addSharedBias(*(bias), 1.0f);
} else { } else {
outMat->addBias(*(biases_->getW()), 1.0f); outMat->addBias(*(bias), 1.0f);
} }
} }
......
...@@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>); ops::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -67,7 +67,8 @@ OnehotCrossEntropy Operator. ...@@ -67,7 +67,8 @@ OnehotCrossEntropy Operator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp); ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<float>); ops::OnehotCrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
......
...@@ -63,7 +63,8 @@ Out = X[Index] ...@@ -63,7 +63,8 @@ Out = X[Index]
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp); REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather, REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>); ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
ops::LookupTableOpGrad); lookup_table_grad, ops::LookupTableOpGrad);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>); REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
...@@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp); REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
......
...@@ -79,8 +79,9 @@ class MinusGradOp : public NetOp { ...@@ -79,8 +79,9 @@ class MinusGradOp : public NetOp {
} // namespace paddle } // namespace paddle
USE_OP(scale); USE_OP(scale);
USE_OP_ITSELF(identity); USE_NO_KERNEL_OP(identity);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp<float>); REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
ops::MinusGradOp<float>);
REGISTER_OP_CPU_KERNEL(minus, REGISTER_OP_CPU_KERNEL(minus,
ops::MinusKernel<paddle::platform::CPUPlace, float>); ops::MinusKernel<paddle::platform::CPUPlace, float>);
...@@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -84,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad, REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>); ops::MulGradKernel<paddle::platform::CPUPlace, float>);
...@@ -235,5 +235,5 @@ RecurrentGradientOp::RecurrentGradientOp( ...@@ -235,5 +235,5 @@ RecurrentGradientOp::RecurrentGradientOp(
} // namespace paddle } // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(
recurrent_op, paddle::operators::RecurrentOp, recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
...@@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { ...@@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
ops::RowwiseAddGradOp); rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>); rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -97,7 +97,7 @@ class IdentityOp : public NetOp { ...@@ -97,7 +97,7 @@ class IdentityOp : public NetOp {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
ops::ScaleGradOp<float>); ops::ScaleGradOp<float>);
REGISTER_OP_CPU_KERNEL(scale, REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>); ops::ScaleKernel<paddle::platform::CPUPlace, float>);
......
...@@ -77,7 +77,8 @@ Out[Index] = Ref[Index] + Updates ...@@ -77,7 +77,8 @@ Out[Index] = Ref[Index] + Updates
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp); REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
ops::ScatterGradOp);
REGISTER_OP_CPU_KERNEL(scatter, REGISTER_OP_CPU_KERNEL(scatter,
ops::ScatterOpKernel<paddle::platform::CPUPlace, float>); ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -53,7 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { ...@@ -53,7 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad); REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>); ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -62,7 +62,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -62,7 +62,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>); ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -22,3 +22,5 @@ ENDIF() ...@@ -22,3 +22,5 @@ ENDIF()
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace platform {
enum class DataLayout {
kNHWC,
kNCHW,
kNCHW_VECT_C,
};
enum class PoolingMode {
kMaximum,
kAverage,
};
template <typename T>
class CudnnDataType;
template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
};
template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
};
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
switch (order) {
case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW;
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
return CUDNN_TENSOR_NCHW;
}
class ScopedTensorDescriptor {
public:
ScopedTensorDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateTensorDescriptor(&desc_));
}
~ScopedTensorDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyTensorDescriptor(desc_));
}
inline cudnnTensorDescriptor_t descriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t type,
const std::vector<int>& dims) {
// the format is not used now, but it maybe useful feature
std::vector<int> strides(dims.size());
strides[dims.size() - 1] = 1;
for (int i = dims.size() - 2; i >= 0; i--) {
strides[i] = dims[i + 1] * strides[i + 1];
}
PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims.size(), dims.data(), strides.data()));
return desc_;
}
template <typename T>
inline cudnnTensorDescriptor_t descriptor(const DataLayout& order,
const std::vector<int>& dims) {
return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
dims);
}
private:
cudnnTensorDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};
class ScopedFilterDescriptor {
public:
ScopedFilterDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateFilterDescriptor(&desc_));
}
~ScopedFilterDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyFilterDescriptor(desc_));
}
inline cudnnFilterDescriptor_t descriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t type,
const std::vector<int>& kernel) {
// filter layout: output input spatial_dim_y spatial_dim_x
PADDLE_ENFORCE(dynload::cudnnSetFilterNdDescriptor(
desc_, type, format, kernel.size(), kernel.data()));
return desc_;
}
template <typename T>
inline cudnnFilterDescriptor_t descriptor(const DataLayout& order,
const std::vector<int>& kernel) {
return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
kernel);
}
private:
cudnnFilterDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateConvolutionDescriptor(&desc_));
}
~ScopedConvolutionDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyConvolutionDescriptor(desc_));
}
inline cudnnConvolutionDescriptor_t descriptor(
cudnnDataType_t type, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations) {
PADDLE_ENFORCE_EQ(pads.size(), strides.size());
PADDLE_ENFORCE_EQ(pads.size(), dilations.size());
#if CUDNN_VERSION < 6000
// cudnn v5 does not support dilation conv, the argument is called upscale
// instead of dilations and it is must be one.
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_EQ(
dilations[i], 1,
"Dilations conv is not supported in this cuDNN version");
}
#endif
PADDLE_ENFORCE(dynload::cudnnSetConvolutionNdDescriptor(
desc_, pads.size(), pads.data(), strides.data(), dilations.data(),
CUDNN_CROSS_CORRELATION, type));
return desc_;
}
template <typename T>
inline cudnnConvolutionDescriptor_t descriptor(
const std::vector<int>& pads, const std::vector<int>& strides,
const std::vector<int>& dilations) {
return descriptor(CudnnDataType<T>::type, pads, strides, dilations);
}
private:
cudnnConvolutionDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
};
class ScopedPoolingDescriptor {
public:
ScopedPoolingDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreatePoolingDescriptor(&desc_));
}
~ScopedPoolingDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyPoolingDescriptor(desc_));
}
inline cudnnPoolingDescriptor_t descriptor(const PoolingMode& mode,
const std::vector<int>& kernel,
const std::vector<int>& pads,
const std::vector<int>& strides) {
PADDLE_ENFORCE_EQ(kernel.size(), pads.size());
PADDLE_ENFORCE_EQ(kernel.size(), strides.size());
PADDLE_ENFORCE(dynload::cudnnSetPoolingNdDescriptor(
desc_, (mode == PoolingMode::kMaximum
? CUDNN_POOLING_MAX
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
CUDNN_PROPAGATE_NAN, // Always propagate nans.
kernel.size(), kernel.data(), pads.data(), strides.data()));
return desc_;
}
private:
cudnnPoolingDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/cudnn_helper.h"
#include <gtest/gtest.h>
TEST(CudnnHelper, ScopedTensorDescriptor) {
using paddle::platform::ScopedTensorDescriptor;
using paddle::platform::DataLayout;
ScopedTensorDescriptor tensor_desc;
std::vector<int> shape = {2, 4, 6, 6};
auto desc = tensor_desc.descriptor<float>(DataLayout::kNCHW, shape);
cudnnDataType_t type;
int nd;
std::vector<int> dims(4);
std::vector<int> strides(4);
paddle::platform::dynload::cudnnGetTensorNdDescriptor(
desc, 4, &type, &nd, dims.data(), strides.data());
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < dims.size(); ++i) {
EXPECT_EQ(dims[i], shape[i]);
}
EXPECT_EQ(strides[3], 1);
EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144);
}
TEST(CudnnHelper, ScopedFilterDescriptor) {
using paddle::platform::ScopedFilterDescriptor;
using paddle::platform::DataLayout;
ScopedFilterDescriptor filter_desc;
std::vector<int> shape = {2, 3, 3};
auto desc = filter_desc.descriptor<float>(DataLayout::kNCHW, shape);
cudnnDataType_t type;
int nd;
cudnnTensorFormat_t format;
std::vector<int> kernel(3);
paddle::platform::dynload::cudnnGetFilterNdDescriptor(desc, 3, &type, &format,
&nd, kernel.data());
EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format);
EXPECT_EQ(nd, 3);
for (size_t i = 0; i < shape.size(); ++i) {
EXPECT_EQ(kernel[i], shape[i]);
}
}
TEST(CudnnHelper, ScopedConvolutionDescriptor) {
using paddle::platform::ScopedConvolutionDescriptor;
ScopedConvolutionDescriptor conv_desc;
std::vector<int> src_pads = {2, 2, 2};
std::vector<int> src_strides = {1, 1, 1};
std::vector<int> src_dilations = {1, 1, 1};
auto desc = conv_desc.descriptor<float>(src_pads, src_strides, src_dilations);
cudnnDataType_t type;
cudnnConvolutionMode_t mode;
int nd;
std::vector<int> pads(3);
std::vector<int> strides(3);
std::vector<int> dilations(3);
paddle::platform::dynload::cudnnGetConvolutionNdDescriptor(
desc, 3, &nd, pads.data(), strides.data(), dilations.data(), &mode,
&type);
EXPECT_EQ(nd, 3);
for (size_t i = 0; i < src_pads.size(); ++i) {
EXPECT_EQ(pads[i], src_pads[i]);
EXPECT_EQ(strides[i], src_strides[i]);
EXPECT_EQ(dilations[i], src_dilations[i]);
}
EXPECT_EQ(mode, CUDNN_CROSS_CORRELATION);
}
TEST(CudnnHelper, ScopedPoolingDescriptor) {
using paddle::platform::ScopedPoolingDescriptor;
using paddle::platform::PoolingMode;
ScopedPoolingDescriptor pool_desc;
std::vector<int> src_kernel = {2, 2, 5};
std::vector<int> src_pads = {1, 1, 2};
std::vector<int> src_strides = {2, 2, 3};
auto desc = pool_desc.descriptor(PoolingMode::kMaximum, src_kernel, src_pads,
src_strides);
cudnnPoolingMode_t mode;
cudnnNanPropagation_t nan_t = CUDNN_PROPAGATE_NAN;
int nd;
std::vector<int> kernel(3);
std::vector<int> pads(3);
std::vector<int> strides(3);
paddle::platform::dynload::cudnnGetPoolingNdDescriptor(
desc, 3, &mode, &nan_t, &nd, kernel.data(), pads.data(), strides.data());
EXPECT_EQ(nd, 3);
for (size_t i = 0; i < src_pads.size(); ++i) {
EXPECT_EQ(kernel[i], src_kernel[i]);
EXPECT_EQ(pads[i], src_pads[i]);
EXPECT_EQ(strides[i], src_strides[i]);
}
EXPECT_EQ(mode, CUDNN_POOLING_MAX);
}
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc) nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc DEPS dynamic_loader)
...@@ -62,19 +62,27 @@ extern void* cudnn_dso_handle; ...@@ -62,19 +62,27 @@ extern void* cudnn_dso_handle;
#define CUDNN_DNN_ROUTINE_EACH(__macro) \ #define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \ __macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \ __macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \ __macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \ __macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \ __macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \ __macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \ __macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \ __macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \ __macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \ __macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \ __macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \ __macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \ __macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \ __macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \ __macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnCreate); \ __macro(cudnnCreate); \
__macro(cudnnDestroy); \ __macro(cudnnDestroy); \
__macro(cudnnSetStream); \ __macro(cudnnSetStream); \
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
classname& operator=(const classname&) = delete
#endif
...@@ -39,12 +39,12 @@ USE_OP(sigmoid); ...@@ -39,12 +39,12 @@ USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
USE_OP(rowwise_add); USE_OP(rowwise_add);
USE_OP(fill_zeros_like); USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op); USE_NO_KERNEL_OP(recurrent);
USE_OP(gaussian_random); USE_OP(gaussian_random);
USE_OP(uniform_random); USE_OP(uniform_random);
USE_OP(lookup_table); USE_OP(lookup_table);
USE_OP(scale); USE_OP(scale);
USE_OP_ITSELF(identity); USE_NO_KERNEL_OP(identity);
USE_OP(minus); USE_OP(minus);
USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter); USE_CPU_ONLY_OP(scatter);
......
...@@ -53,7 +53,7 @@ __all__ = [ ...@@ -53,7 +53,7 @@ __all__ = [
'cos_sim', 'cos_sim',
'hsigmoid', 'hsigmoid',
'conv_projection', 'conv_projection',
'mse_cost', 'square_error_cost',
'regression_cost', 'regression_cost',
'classification_cost', 'classification_cost',
'LayerOutput', 'LayerOutput',
...@@ -4240,13 +4240,18 @@ def __cost_input__(input, label, weight=None): ...@@ -4240,13 +4240,18 @@ def __cost_input__(input, label, weight=None):
@wrap_name_default() @wrap_name_default()
@layer_support() @layer_support()
def mse_cost(input, label, weight=None, name=None, coeff=1.0, layer_attr=None): def square_error_cost(input,
label,
weight=None,
name=None,
coeff=1.0,
layer_attr=None):
""" """
mean squared error cost: sum of square error cost:
.. math:: .. math::
\\frac{1}{N}\sum_{i=1}^N(t_i-y_i)^2 cost = \\sum_{i=1}^N(t_i-y_i)^2
:param name: layer name. :param name: layer name.
:type name: basestring :type name: basestring
...@@ -4275,7 +4280,7 @@ def mse_cost(input, label, weight=None, name=None, coeff=1.0, layer_attr=None): ...@@ -4275,7 +4280,7 @@ def mse_cost(input, label, weight=None, name=None, coeff=1.0, layer_attr=None):
return LayerOutput(name, LayerType.COST, parents=parents, size=1) return LayerOutput(name, LayerType.COST, parents=parents, size=1)
regression_cost = mse_cost regression_cost = square_error_cost
@wrap_name_default("cost") @wrap_name_default("cost")
...@@ -5800,9 +5805,9 @@ def huber_regression_cost(input, ...@@ -5800,9 +5805,9 @@ def huber_regression_cost(input,
coeff=1.0, coeff=1.0,
layer_attr=None): layer_attr=None):
""" """
In statistics, the Huber loss is a loss function used in robust regression, In statistics, the Huber loss is a loss function used in robust regression,
that is less sensitive to outliers in data than the squared error loss. that is less sensitive to outliers in data than the squared error loss.
Given a prediction f(x), a label y and :math:`\delta`, the loss function Given a prediction f(x), a label y and :math:`\delta`, the loss function
is defined as: is defined as:
.. math: .. math:
...@@ -5850,13 +5855,13 @@ def huber_classification_cost(input, ...@@ -5850,13 +5855,13 @@ def huber_classification_cost(input,
coeff=1.0, coeff=1.0,
layer_attr=None): layer_attr=None):
""" """
For classification purposes, a variant of the Huber loss called modified Huber For classification purposes, a variant of the Huber loss called modified Huber
is sometimes used. Given a prediction f(x) (a real-valued classifier score) and is sometimes used. Given a prediction f(x) (a real-valued classifier score) and
a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber
loss is defined as: loss is defined as:
.. math: .. math:
loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1 loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1
loss = -4yf(x), \text{otherwise} loss = -4yf(x), \text{otherwise}
The example usage is: The example usage is:
......
...@@ -45,7 +45,7 @@ layers { ...@@ -45,7 +45,7 @@ layers {
coeff: 1.0 coeff: 1.0
} }
layers { layers {
name: "__mse_cost_0__" name: "__square_error_cost_0__"
type: "square_error" type: "square_error"
size: 1 size: 1
active_type: "" active_type: ""
...@@ -130,7 +130,7 @@ input_layer_names: "label" ...@@ -130,7 +130,7 @@ input_layer_names: "label"
input_layer_names: "weight" input_layer_names: "weight"
input_layer_names: "multi_class_label" input_layer_names: "multi_class_label"
output_layer_names: "__cost_0__" output_layer_names: "__cost_0__"
output_layer_names: "__mse_cost_0__" output_layer_names: "__square_error_cost_0__"
output_layer_names: "__nce_layer_0__" output_layer_names: "__nce_layer_0__"
evaluators { evaluators {
name: "classification_error_evaluator" name: "classification_error_evaluator"
...@@ -146,7 +146,7 @@ sub_models { ...@@ -146,7 +146,7 @@ sub_models {
layer_names: "weight" layer_names: "weight"
layer_names: "__fc_layer_0__" layer_names: "__fc_layer_0__"
layer_names: "__cost_0__" layer_names: "__cost_0__"
layer_names: "__mse_cost_0__" layer_names: "__square_error_cost_0__"
layer_names: "multi_class_label" layer_names: "multi_class_label"
layer_names: "__nce_layer_0__" layer_names: "__nce_layer_0__"
input_layer_names: "input" input_layer_names: "input"
...@@ -154,7 +154,7 @@ sub_models { ...@@ -154,7 +154,7 @@ sub_models {
input_layer_names: "weight" input_layer_names: "weight"
input_layer_names: "multi_class_label" input_layer_names: "multi_class_label"
output_layer_names: "__cost_0__" output_layer_names: "__cost_0__"
output_layer_names: "__mse_cost_0__" output_layer_names: "__square_error_cost_0__"
output_layer_names: "__nce_layer_0__" output_layer_names: "__nce_layer_0__"
evaluator_names: "classification_error_evaluator" evaluator_names: "classification_error_evaluator"
is_recurrent_layer_group: false is_recurrent_layer_group: false
......
...@@ -10,7 +10,7 @@ fc = fc_layer(input=data, size=10, act=SoftmaxActivation()) ...@@ -10,7 +10,7 @@ fc = fc_layer(input=data, size=10, act=SoftmaxActivation())
outputs( outputs(
classification_cost( classification_cost(
input=fc, label=lbl, weight=wt), input=fc, label=lbl, weight=wt),
mse_cost( square_error_cost(
input=fc, label=lbl, weight=wt), input=fc, label=lbl, weight=wt),
nce_layer( nce_layer(
input=fc, input=fc,
......
...@@ -179,7 +179,7 @@ class OperatorFactory(object): ...@@ -179,7 +179,7 @@ class OperatorFactory(object):
class __RecurrentOp__(object): class __RecurrentOp__(object):
__proto__ = None __proto__ = None
type = 'recurrent_op' type = 'recurrent'
def __init__(self): def __init__(self):
# cache recurrent_op's proto # cache recurrent_op's proto
......
...@@ -134,8 +134,9 @@ class CostLayerTest(unittest.TestCase): ...@@ -134,8 +134,9 @@ class CostLayerTest(unittest.TestCase):
cost3 = layer.cross_entropy_cost(input=inference, label=label) cost3 = layer.cross_entropy_cost(input=inference, label=label)
cost4 = layer.cross_entropy_with_selfnorm_cost( cost4 = layer.cross_entropy_with_selfnorm_cost(
input=inference, label=label) input=inference, label=label)
cost5 = layer.mse_cost(input=inference, label=label) cost5 = layer.square_error_cost(input=inference, label=label)
cost6 = layer.mse_cost(input=inference, label=label, weight=weight) cost6 = layer.square_error_cost(
input=inference, label=label, weight=weight)
cost7 = layer.multi_binary_label_cross_entropy_cost( cost7 = layer.multi_binary_label_cross_entropy_cost(
input=inference, label=label) input=inference, label=label)
cost8 = layer.rank_cost(left=score, right=score, label=score) cost8 = layer.rank_cost(left=score, right=score, label=score)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册