diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst
index 1329b77bb44f52c66a703740715b890c47234e72..c94627a72806fa2eca77c79da24f7f3ca18f0259 100644
--- a/doc/api/v2/config/layer.rst
+++ b/doc/api/v2/config/layer.rst
@@ -434,9 +434,9 @@ lambda_cost
 .. autoclass:: paddle.v2.layer.lambda_cost
     :noindex:
 
-mse_cost
+square_error_cost
--------
-.. autoclass:: paddle.v2.layer.mse_cost
+.. autoclass:: paddle.v2.layer.square_error_cost
     :noindex:
 
 rank_cost
diff --git a/doc/design/graph.md b/doc/design/graph.md
new file mode 100644
index 0000000000000000000000000000000000000000..87f696f90f164a639ad5182823ddfb14aab7e065
--- /dev/null
+++ b/doc/design/graph.md
@@ -0,0 +1,51 @@
+# Design Doc: Computations as Graphs
+
+A primary goal of the refactorization of PaddlePaddle is a more flexible representation of deep learning computation: a graph of operators and variables, instead of sequences of layers as before.
+
+This document explains the construction of such a graph in three steps:
+
+- construct the forward part
+- construct the backward part
+- construct the optimization part
+
+Let us take the problem of image classification as a simple example. The application program that trains the model looks like:
+
+```python
+x = layer.data("images")
+l = layer.data("label")
+y = layer.fc(x)
+cost = layer.mse(y, l)
+optimize(cost)
+train(cost, reader=mnist.train())
+```
+
+### Forward Part
+
+The first four lines of the above program build the forward part of the graph.
+
+![](images/graph_construction_example_forward_only.png)
+
+In particular, the first line `x = layer.data("images")` creates variable x and a Feed operator that copies a column from the minibatch to x. `y = layer.fc(x)` creates not only the FC operator and output variable y, but also two parameters, W and b.
+
+In this example, all operators are created as `OpDesc` protobuf messages, and all variables are `VarDesc`. These protobuf messages are saved in a `BlockDesc` protobuf message.
+
+### Backward Part
+
+The fifth line `optimize(cost)` calls two functions, `ConstructBackwardGraph` and `ConstructOptimizationGraph`.
+
+`ConstructBackwardGraph` traverses the forward graph in the `BlockDesc` protobuf message and builds the backward part.
+
+![](images/graph_construction_example_forward_backward.png)
+
+According to the chain rule of gradient computation, `ConstructBackwardGraph` would
+
+1. create a gradient operator G for each operator F,
+1. make all inputs, outputs, and output gradients of F the inputs of G,
+1. create gradients for all inputs of F, except for those that don't have gradients, like x and l, and
+1. make all these gradients the outputs of G.
+
+### Optimization Part
+
+For each parameter, like W and b created by `layer.fc`, marked as double circles in the above graphs, `ConstructOptimizationGraph` creates an optimization operator to apply its gradient. This results in the complete graph:
+
+![](images/graph_construction_example_all.png)
diff --git a/doc/design/if_else_op.md b/doc/design/if_else_op.md
new file mode 100644
index 0000000000000000000000000000000000000000..7370c2a24fa644a64e738f202bac9b9209642e08
--- /dev/null
+++ b/doc/design/if_else_op.md
@@ -0,0 +1,59 @@
+IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponding to a true element in `cond`.
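+
+For intuition: the selection behaves like a boolean-mask gather over the minibatch. The following is a minimal NumPy sketch of these semantics only, not the proposed Paddle API:
+
+```python
+import numpy as np
+
+cond = np.array([True, False, True, True])        # N = 4 boolean elements
+x = np.arange(8, dtype=np.float32).reshape(4, 2)  # N = 4 input instances
+out = x[cond]                                     # keeps the M = 3 true rows
+assert out.shape == (3, 2)
+```
+
+The proposed Python API for building an IfOp looks like: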
+
+```python
+import paddle as pd
+
+x = var()
+y = var()
+cond = var()
+
+b = pd.create_ifop(inputs=[x], output_num=1)
+with b.true_block():
+    x = b.inputs(0)
+    z = operator.add(x, y)
+    b.set_output(0, operator.softmax(z))
+
+out = b(cond)
+```
+
+If we want the output to still have N instances, we can use an IfElseOp with a default value, whose minibatch size must be N:
+
+```python
+import paddle as pd
+
+x = var()
+y = var()
+cond = var()
+default_value = var()
+b = pd.create_ifelseop(inputs=[x], output_num=1)
+with b.true_block():
+    x = b.inputs(0)
+    z = operator.add(x, y)
+    b.set_output(0, operator.softmax(z))
+
+with b.false_block():
+    x = b.inputs(0)
+    z = layer.fc(x)
+    b.set_output(0, operator.softmax(z))
+
+out = b(cond)
+```
+
+If only the true_block is set in an IfElseOp, we can supply a default value for the false branch:
+```python
+import paddle as pd
+
+x = var()
+y = var()
+cond = var()
+default_value = var()
+b = pd.create_ifelseop(inputs=[x], output_num=1, default_value=default_value)
+
+with b.true_block():
+    x = b.inputs(0)
+    z = operator.add(x, y)
+    b.set_output(0, operator.softmax(z))
+
+out = b(cond)
+```
+where `default_value` is a list of vars used as the output for the instances where `cond` is False.
diff --git a/doc/design/images/graph_construction_example.bash b/doc/design/images/graph_construction_example.bash
new file mode 100755
index 0000000000000000000000000000000000000000..35e6997abd17588e17a82d448918fc1b3bd7220e
--- /dev/null
+++ b/doc/design/images/graph_construction_example.bash
@@ -0,0 +1,11 @@
+cat ./graph_construction_example.dot | \
+  sed 's/color=red/color=red, style=invis/g' | \
+  sed 's/color=green/color=green, style=invis/g' | \
+  dot -Tpng > graph_construction_example_forward_only.png
+
+cat ./graph_construction_example.dot | \
+  sed 's/color=green/color=green, style=invis/g' | \
+  dot -Tpng > graph_construction_example_forward_backward.png
+
+cat ./graph_construction_example.dot | \
+  dot -Tpng > graph_construction_example_all.png
diff --git a/doc/design/images/graph_construction_example.dot b/doc/design/images/graph_construction_example.dot
new file mode 100644
index 0000000000000000000000000000000000000000..bedb6de0111a8ccab4030d034d65cf72705fc25a
--- /dev/null
+++ b/doc/design/images/graph_construction_example.dot
@@ -0,0 +1,65 @@
+digraph ImageClassificationGraph {
+  ///////// The forward part /////////
+  FeedX [label="Feed", color=blue, shape=box];
+  FeedY [label="Feed", color=blue, shape=box];
+  FC [label="FC", color=blue, shape=box];
+  MSE [label="MSE", color=blue, shape=box];
+
+  x [label="x", color=blue, shape=oval];
+  l [label="l", color=blue, shape=oval];
+  y [label="y", color=blue, shape=oval];
+  W [label="W", color=blue, shape=doublecircle];
+  b [label="b", color=blue, shape=doublecircle];
+  cost [label="cost", color=blue, shape=oval];
+
+  FeedX -> x -> FC -> y -> MSE -> cost [color=blue];
+  FeedY -> l [color=blue];
+  W -> FC [color=blue];
+  b -> FC [color=blue];
+  l -> MSE [color=blue];
+
+  ////////// The backward part /////////
+  MSE_Grad [label="MSE_grad", color=red, shape=box];
+  FC_Grad [label="FC_grad", color=red, shape=box];
+
+  d_cost [label="d cost", color=red, shape=oval];
+  d_y [label="d y", color=red, shape=oval];
+  d_b [label="d b", color=red, shape=oval];
+  d_W [label="d W", color=red, shape=oval];
+
+  cost -> MSE_Grad [color=red];
+  d_cost -> MSE_Grad [color=red];
+  x -> MSE_Grad [color=red];
+  l -> MSE_Grad [color=red];
+  y -> MSE_Grad -> d_y [color=red];
+
+  x -> FC_Grad [color=red];
+  y -> FC_Grad [color=red];
+  d_y -> FC_Grad [color=red];
+  W -> FC_Grad -> d_W [color=red];
+  b -> FC_Grad -> d_b [color=red];
+
+  ////////// The optimization part //////////
+
+  OPT_W [label="SGD", color=green, shape=box];
+  OPT_b [label="SGD", color=green, shape=box];
+
+  W -> OPT_W [color=green];
+  b -> OPT_b [color=green];
+  d_W -> OPT_W -> W [color=green];
+  d_b -> OPT_b -> b [color=green];
+
+  ////////// Groupings //////////
+
+  subgraph clusterMSE {
+    style=invis;
+    MSE;
+    MSE_Grad;
+  }
+
+  subgraph clusterFC {
+    style=invis;
+    FC;
+    FC_Grad;
+  }
+}
diff --git a/doc/design/images/graph_construction_example_all.png b/doc/design/images/graph_construction_example_all.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d8330b60e12720bb993c8cf588d64ff8db1ea9
Binary files /dev/null and b/doc/design/images/graph_construction_example_all.png differ
diff --git a/doc/design/images/graph_construction_example_forward_backward.png b/doc/design/images/graph_construction_example_forward_backward.png
new file mode 100644
index 0000000000000000000000000000000000000000..61c3a02a04bc8891ab5b921a889829bcce386df8
Binary files /dev/null and b/doc/design/images/graph_construction_example_forward_backward.png differ
diff --git a/doc/design/images/graph_construction_example_forward_only.png b/doc/design/images/graph_construction_example_forward_only.png
new file mode 100644
index 0000000000000000000000000000000000000000..14805df11fc09f64d6bc17f5e969f1400d615148
Binary files /dev/null and b/doc/design/images/graph_construction_example_forward_only.png differ
diff --git a/doc/getstarted/basic_usage/index_cn.rst b/doc/getstarted/basic_usage/index_cn.rst
index 428f58830e0b10c024f31238b7404c6df193eecd..b473944fc7fb89d3e0a0b330933f2226734bb5bd 100644
--- a/doc/getstarted/basic_usage/index_cn.rst
+++ b/doc/getstarted/basic_usage/index_cn.rst
@@ -55,7 +55,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍
     # 线性计算网络层: ȳ = wx + b
     ȳ = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b'))
     # 计算误差函数,即 ȳ 和真实 y 之间的距离
-    cost = mse_cost(input= ȳ, label=y)
+    cost = square_error_cost(input= ȳ, label=y)
     outputs(cost)
@@ -69,7 +69,7 @@ PaddlePaddle是源于百度的一个深度学习平台。这份简短的介绍
   - **数据层**:数据层 `data_layer` 是神经网络的入口,它读入数据并将它们传输到接下来的网络层。这里数据层有两个,分别对应于变量 `x` 和 `y`。
   - **全连接层**:全连接层 `fc_layer` 是基础的计算单元,这里利用它建模变量之间的线性关系。计算单元是神经网络的核心,PaddlePaddle支持大量的计算单元和任意深度的网络连接,从而可以拟合任意的函数来学习复杂的数据关系。
-  - **回归误差代价层**:回归误差代价层 `mse_cost` 是众多误差代价函数层的一种,它们在训练过程作为网络的出口,用来计算模型的误差,是模型参数优化的目标函数。
+  - **回归误差代价层**:回归误差代价层 `square_error_cost` 是众多误差代价函数层的一种,它们在训练过程作为网络的出口,用来计算模型的误差,是模型参数优化的目标函数。
 
 定义了网络结构并保存为 `trainer_config.py` 之后,运行以下训练命令:
diff --git a/doc/getstarted/basic_usage/index_en.rst b/doc/getstarted/basic_usage/index_en.rst
index 6775da20c2f51000f305b095d40abd27b8fa6c0e..2cc438ebbe0f97345d25354b93b4ebbd43502415 100644
--- a/doc/getstarted/basic_usage/index_en.rst
+++ b/doc/getstarted/basic_usage/index_en.rst
@@ -49,7 +49,7 @@ To recover this relationship between ``X`` and ``Y``, we use a neural network wi
     x = data_layer(name='x', size=1)
     y = data_layer(name='y', size=1)
     y_predict = fc_layer(input=x, param_attr=ParamAttr(name='w'), size=1, act=LinearActivation(), bias_attr=ParamAttr(name='b'))
-    cost = mse_cost(input=y_predict, label=y)
+    cost = square_error_cost(input=y_predict, label=y)
     outputs(cost)
 
 Some of the most fundamental usages of PaddlePaddle are demonstrated:
diff --git a/doc/getstarted/concepts/src/train.py b/doc/getstarted/concepts/src/train.py
index 7e604f23de38543a00f305d508af0791193f78ba..8aceb23406a476f08639cc6223cdf730b728a705 100644
--- 
a/doc/getstarted/concepts/src/train.py +++ b/doc/getstarted/concepts/src/train.py @@ -8,7 +8,7 @@ paddle.init(use_gpu=False) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(2)) y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) -cost = paddle.layer.mse_cost(input=y_predict, label=y) +cost = paddle.layer.square_error_cost(input=y_predict, label=y) # create parameters parameters = paddle.parameters.create(cost) diff --git a/doc/getstarted/concepts/use_concepts_cn.rst b/doc/getstarted/concepts/use_concepts_cn.rst index f15b11bd780402a3ec1755900e8c648f5d2a7bc5..c243083794bb3c4659242de99b3b2715af9d7c24 100644 --- a/doc/getstarted/concepts/use_concepts_cn.rst +++ b/doc/getstarted/concepts/use_concepts_cn.rst @@ -81,9 +81,9 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 .. code-block:: bash y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) - cost = paddle.layer.mse_cost(input=y_predict, label=y) + cost = paddle.layer.square_error_cost(input=y_predict, label=y) -其中,x与y为之前描述的输入层;而y_predict是接收x作为输入,接上一个全连接层;cost接收y_predict与y作为输入,接上均方误差层。 +其中,x与y为之前描述的输入层;而y_predict是接收x作为输入,接上一个全连接层;cost接收y_predict与y作为输入,接上平方误差层。 最后一层cost中记录了神经网络的所有拓扑结构,通过组合不同的layer,我们即可完成神经网络的搭建。 @@ -147,4 +147,4 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 .. literalinclude:: src/train.py :linenos: -有关线性回归的实际应用,可以参考PaddlePaddle book的 `第一章节 `_。 \ No newline at end of file +有关线性回归的实际应用,可以参考PaddlePaddle book的 `第一章节 `_。 diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index ec79b7f42b2d70df8fcb25faca5bc3a4759e177c..7f8da2da5a0d42ff065265c5d173d0e6167dc08a 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -178,13 +178,13 @@ class MulKernel : public framework::OpKernel { ```c++ namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); REGISTER_OP_CPU_KERNEL(mul_grad, ops::MulGradKernel); ``` - - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,并且注册`ops::MulOpGrad`为其反向Op。 + - `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,注册`ops::MulOpGrad`,类型名为`mul_grad`, - `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。 - `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulKernel`类。 diff --git a/doc/howto/usage/k8s/k8s_distributed_cn.md b/doc/howto/usage/k8s/k8s_distributed_cn.md index 3121b3f59df650c0a22d0bd305a6f793b202d30e..a9bebf09558b06993119803458977abedbbfbdd0 100644 --- a/doc/howto/usage/k8s/k8s_distributed_cn.md +++ b/doc/howto/usage/k8s/k8s_distributed_cn.md @@ -213,7 +213,7 @@ I1116 09:10:17.123440 50 Util.cpp:130] Calling runInitFunctions I1116 09:10:17.123764 50 Util.cpp:143] Call runInitFunctions done. [WARNING 2016-11-16 09:10:17,227 default_decorators.py:40] please use keyword arguments in paddle config. 
[INFO 2016-11-16 09:10:17,239 networks.py:1282] The input order is [movie_id, title, genres, user_id, gender, age, occupation, rating] -[INFO 2016-11-16 09:10:17,239 networks.py:1289] The output order is [__mse_cost_0__] +[INFO 2016-11-16 09:10:17,239 networks.py:1289] The output order is [__square_error_cost_0__] I1116 09:10:17.392917 50 Trainer.cpp:170] trainer mode: Normal I1116 09:10:17.613910 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process I1116 09:10:17.680917 50 PyDataProvider2.cpp:257] loading dataprovider dataprovider::process diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md index 9500c92a265d60a696e1e2c422d0f2bd1621ef71..8aa6728a95bc464ab8884986f0cec6c817d3303b 100644 --- a/paddle/framework/backward.md +++ b/paddle/framework/backward.md @@ -18,7 +18,7 @@ A backward network is built up with several backward operators. Backward operato For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro: ```cpp -REGISTER_OP(mul, MulOp, MulOpMaker, MulOpGrad); +REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); ``` `mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index f4d99f55b5ec00fcca61d81e7bf71f7791248fe9..ad8003420dc14538d0dae9a1cb19d6459b154576 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -127,8 +127,8 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { public: FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("x", "x"); - AddOutput("out", "out"); + AddInput("Src", "x"); + AddOutput("Dst", "out"); AddComment(""); } }; @@ -138,7 +138,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "x").AsDuplicable(); - AddOutput("Y", "y"); + AddOutput("Out", "out"); AddComment(""); } }; @@ -148,14 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, f::NOP); -REGISTER_OP(mul, f::NOP, f::MulOpMaker, f::NOP); -REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, f::NOP); +REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, + f::NOP); +REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); +REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); -REGISTER_OP(add, f::NOP, f::AddOpMaker, f::NOP); +REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); -REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, f::NOP); +REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, + f::NOP); TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp( diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 8a817a3e13ca64d6f8df566891a1059995e041ae..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -54,8 +54,8 @@ 
TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); } -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, f::NOP); +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); TEST(GradOpBuilder, MutiInOut) { std::shared_ptr test_op(f::OpRegistry::CreateOp( diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 94245c6c44aca962b0db890947a9dc5550ac0799..b98d8f23a14cf6fbe787953ad16b5c9ab99222ad 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -80,9 +80,19 @@ class OpInfoMap { } const OpInfo& Get(const std::string& type) const { + auto op_info_ptr = GetNullable(type); + PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered", + type); + return *op_info_ptr; + } + + const OpInfo* GetNullable(const std::string& type) const { auto it = map_.find(type); - PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); - return it->second; + if (it == map_.end()) { + return nullptr; + } else { + return &it->second; + } } template diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 64c7f23ab6b79bad9533f566ca39db3cfd5ac5c5..2d09cde41e3f5086279f9441e0fdc52549bed5ab 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -33,7 +33,8 @@ namespace framework { class OpRegistry { public: template - static void RegisterOp(const std::string& op_type) { + static void RegisterOp(const std::string& op_type, + const std::string& grad_op_type) { PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); OpInfo op_info; @@ -42,9 +43,9 @@ class OpRegistry { const VariableNameMap& outputs, const AttributeMap& attrs) { return new OpType(type, inputs, outputs, attrs); }; + op_info.grad_op_type_ = grad_op_type; if (std::type_index(typeid(ProtoMakerType)) != std::type_index(typeid(NOPMaker))) { - op_info.grad_op_type_ = op_type + "_grad"; op_info.proto_ = new OpProto; op_info.checker_ = new OpAttrChecker; auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); @@ -54,14 +55,15 @@ class OpRegistry { op_info.proto_->IsInitialized(), "Fail to initialize %s's OpProto, because %s is not initialized", op_type, op_info.proto_->InitializationErrorString()); - // register gradient op - RegisterOp(op_info.grad_op_type_); } else { - op_info.grad_op_type_ = ""; op_info.proto_ = nullptr; op_info.checker_ = nullptr; } OpInfoMap::Instance().Insert(op_type, op_info); + // register gradient op + if (!grad_op_type.empty()) { + RegisterOp(grad_op_type, ""); + } } static std::unique_ptr CreateOp(const std::string& type, @@ -90,8 +92,10 @@ class Registrar { template class OpRegistrar : public Registrar { public: - explicit OpRegistrar(const char* op_type) { - OpRegistry::RegisterOp(op_type); + explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } + OpRegistrar(const char* op_type, const char* grad_op_type) { + OpRegistry::RegisterOp(op_type, + grad_op_type); } }; @@ -117,7 +121,8 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. 
*/ -#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class) \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ class _OpClass_##op_type##_ : public op_class { \ @@ -132,14 +137,14 @@ class OpKernelRegistrar : public Registrar { }; \ static ::paddle::framework::OpRegistrar< \ _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ - __op_registrar_##op_type##__(#op_type); \ + __op_registrar_##op_type##__(#op_type, #grad_op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ return 0; \ } #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ - REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP) + REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) /** * Macro to register OperatorKernel. diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 7abbde610f1e9c530393b9a9cabe40b826712212..790cfc4746b1d34da413fa3c29a266f962c6dde6 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -33,12 +33,12 @@ ExecutionContext::GetEigenDevice() const { } #endif -const std::string& OperatorBase::Input(const std::string& name) const { +std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - PADDLE_ENFORCE_EQ(ins.size(), 1UL, + PADDLE_ENFORCE_LE(ins.size(), 1UL, "Op %s input %s should contain only one variable", type_, name); - return ins[0]; + return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( @@ -49,12 +49,12 @@ const std::vector& OperatorBase::Inputs( return it->second; } -const std::string& OperatorBase::Output(const std::string& name) const { +std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - PADDLE_ENFORCE_EQ(outs.size(), 1UL, + PADDLE_ENFORCE_LE(outs.size(), 1UL, "Op %s output %s should contain only one variable", type_, name); - return outs[0]; + return outs.empty() ? 
kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( @@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { - static std::atomic gUniqId(0UL); - for (auto& output : outputs_) { - for (auto& output_name : output.second) { - if (output_name == kTempVarName) { - output_name += type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); - } - } - } + GenerateTemporaryNames(); + CheckAllInputOutputSet(); } std::vector OperatorBase::OutputVars(bool has_intermediate) const { @@ -156,6 +148,35 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { return ret_val; } +void OperatorBase::CheckAllInputOutputSet() const { + auto& info_map = OpInfoMap::Instance(); + auto* op_info = info_map.GetNullable(Type()); + if (op_info == nullptr || op_info->proto_ == nullptr) return; + + for (auto& in : op_info->Proto().inputs()) { + PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), + "Type %s's input %s is not set", Type(), in.name()); + } + + for (auto& out : op_info->Proto().outputs()) { + PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), + "Type %s's output %s is not set", Type(), out.name()); + } +} + +void OperatorBase::GenerateTemporaryNames() { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 8397570d26f06f0238e9c5afc85d721df7679257..da92220b04e313e4743cc77241755b685d0791ad 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -95,12 +95,12 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` - const std::string& Input(const std::string& name) const; + std::string Input(const std::string& name) const; //! Get a input which has multiple variables. const std::vector& Inputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` - const std::string& Output(const std::string& name) const; + std::string Output(const std::string& name) const; //! Get an output which has multiple variables. //! TODO add a vector_view to prevent memory copy. const std::vector& Outputs(const std::string& name) const; @@ -127,6 +127,10 @@ class OperatorBase { // IG (Inputs Gradients) VariableNameMap outputs_; AttributeMap attrs_; + + private: + void GenerateTemporaryNames(); + void CheckAllInputOutputSet() const; }; // Macro for define a clone method. 
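// A hedged caller sketch, not part of this diff: with the change above,
// OperatorBase::Input()/Output() return std::string by value and yield
// kEmptyVarName when an optional argument is unset, so call sites are
// expected to check before resolving the variable. "X" below is an
// illustrative argument name only:
//
//   std::string x_name = op.Input("X");   // may now be kEmptyVarName
//   if (x_name != paddle::framework::kEmptyVarName) {
//     auto* var = scope.FindVar(x_name);  // resolve only when present
//   }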
@@ -229,6 +233,15 @@ class InferShapeContext { InferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} + const OperatorBase& op() const { return op_; } + + const Scope& scope() const { return scope_; } + + template + inline const T& GetAttr(const std::string& name) const { + return op_.GetAttr(name); + } + size_t InputSize(const std::string& name) const { return op_.Inputs(name).size(); } @@ -238,11 +251,13 @@ class InferShapeContext { } const Variable* InputVar(const std::string& name) const { - return scope_.FindVar(op_.Input(name)); + auto ipt = op_.Input(name); + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } Variable* OutputVar(const std::string& name) const { - return scope_.FindVar(op_.Output(name)); + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); } const std::vector MultiInputVar( @@ -250,9 +265,11 @@ class InferShapeContext { auto names = op_.Inputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } @@ -260,24 +277,24 @@ class InferShapeContext { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); + }); return res; } template const T* Input(const std::string& name) const { auto* var = InputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); } template T* Output(const std::string& name) const { auto var = OutputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); } template @@ -288,10 +305,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiInput(%s:%s) should not be nullptr", name, - sub_name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); }); return res; } @@ -304,14 +318,12 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr.", name, - sub_name); - return var->GetMutable(); + return var == nullptr ? 
nullptr : var->GetMutable(); }); return res; } + private: const OperatorBase& op_; const Scope& scope_; }; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 1d7efb7b9403f7c1c6bdbb27a0258f79ae032f43..f7c9e6b196a9d63c91d83fb6d985472a4e8976c4 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -122,10 +122,10 @@ class CPUKernelTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { std::cout << "this is cpu kernel" << std::endl; - std::cout << ctx.op_.DebugString() << std::endl; + std::cout << ctx.op().DebugString() << std::endl; cpu_kernel_run_num++; - ASSERT_EQ(ctx.op_.Input("x"), "IN1"); - ASSERT_EQ(ctx.op_.Output("y"), "OUT1"); + ASSERT_EQ(ctx.op().Input("x"), "IN1"); + ASSERT_EQ(ctx.op().Output("y"), "OUT1"); } }; @@ -148,7 +148,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker class CPUKernalMultiInputsTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { - auto xs = ctx.op_.Inputs("xs"); + auto xs = ctx.op().Inputs("xs"); ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[1], "x1"); @@ -172,10 +172,10 @@ class CPUKernalMultiInputsTest : public OpKernel { auto outTensor0 = ctx.MultiOutput("ys"); ASSERT_EQ(outTensor0.size(), 2U); - auto k = ctx.op_.Input("k"); + auto k = ctx.op().Input("k"); ASSERT_EQ(k, "k0"); - auto ys = ctx.op_.Outputs("ys"); + auto ys = ctx.op().Outputs("ys"); ASSERT_EQ(ys.size(), 2UL); ASSERT_EQ(ys[0], "y0"); ASSERT_EQ(ys[1], "y1"); diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 132119015f967c6e8d055792de8afe8450df5ec6..92087fa32b1e48b50fbf447ec6f3c43e2a510220 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -14,18 +14,20 @@ limitations under the License. */ #include "Evaluator.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h" +#include "paddle/utils/StringUtil.h" namespace paddle { /** * calculate sequence-to-sequence edit distance */ -class CTCErrorEvaluator : public NotGetableEvaluator { +class CTCErrorEvaluator : public Evaluator { private: MatrixPtr outActivations_; int numTimes_, numClasses_, numSequences_, blank_; real deletions_, insertions_, substitutions_; int seqClassficationError_; + mutable std::unordered_map evalResults_; std::vector path2String(const std::vector& path) { std::vector str; @@ -183,6 +185,18 @@ private: return stringAlignment(gtStr, recogStr); } + void storeLocalValues() const { + evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0; + evalResults_["deletion_error"] = + numSequences_ ? deletions_ / numSequences_ : 0; + evalResults_["insertion_error"] = + numSequences_ ? insertions_ / numSequences_ : 0; + evalResults_["substitution_error"] = + numSequences_ ? substitutions_ / numSequences_ : 0; + evalResults_["sequence_error"] = + (real)seqClassficationError_ / numSequences_; + } + public: CTCErrorEvaluator() : numTimes_(0), @@ -245,16 +259,12 @@ public: } virtual void printStats(std::ostream& os) const { - os << config_.name() << "=" - << (numSequences_ ? totalScore_ / numSequences_ : 0); - os << " deletions error" - << "=" << (numSequences_ ? deletions_ / numSequences_ : 0); - os << " insertions error" - << "=" << (numSequences_ ? insertions_ / numSequences_ : 0); - os << " substitutions error" - << "=" << (numSequences_ ? 
substitutions_ / numSequences_ : 0); - os << " sequences error" - << "=" << (real)seqClassficationError_ / numSequences_; + storeLocalValues(); + os << config_.name() << " error = " << evalResults_["error"]; + os << " deletions error = " << evalResults_["deletion_error"]; + os << " insertions error = " << evalResults_["insertion_error"]; + os << " substitution error = " << evalResults_["substitution_error"]; + os << " sequence error = " << evalResults_["sequence_error"]; } virtual void distributeEval(ParameterClient2* client) { @@ -272,6 +282,37 @@ public: seqClassficationError_ = (int)buf[4]; numSequences_ = (int)buf[5]; } + + void getNames(std::vector* names) { + storeLocalValues(); + names->reserve(names->size() + evalResults_.size()); + for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) { + names->push_back(config_.name() + "." + it->first); + } + } + + real getValue(const std::string& name, Error* err) const { + storeLocalValues(); + + std::vector buffers; + paddle::str::split(name, '.', &buffers); + auto it = evalResults_.find(buffers[buffers.size() - 1]); + + if (it == evalResults_.end()) { + *err = Error("Evaluator does not have the key %s", name.c_str()); + return 0.0f; + } + + return it->second; + } + + std::string getType(const std::string& name, Error* err) const { + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } + return "ctc_edit_distance"; + } }; REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 1658282f3a5f79b128ce8685e92fd5cf9db2e41a..a2ab15eedee4aaa7b47504d50e25300359f18173 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -268,7 +268,13 @@ public: } // get type of evaluator - std::string getTypeImpl() const { return "chunk"; } + std::string getType(const std::string& name, Error* err) const { + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } + return "chunk"; + } private: void storeLocalValues() const { diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index b114500e2b7c1e460a02c78b99b5f1a8fb63b8c3..90203553e0a5fe8cc8183274f374da178bae30d0 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -211,6 +211,7 @@ public: *err = Error("Not implemented"); return .0f; } + std::string getType(const std::string& name, Error* err) const { *err = Error("Not implemented"); return ""; @@ -331,6 +332,7 @@ private: protected: std::string getTypeImpl() const; }; + /** * @brief precision, recall and f1 score Evaluator * \f[ @@ -358,6 +360,12 @@ public: virtual void distributeEval(ParameterClient2* client); + void getNames(std::vector* names); + + real getValue(const std::string& name, Error* err) const; + + std::string getType(const std::string& name, Error* err) const; + struct StatsInfo { /// numbers of true positives double TP; @@ -428,11 +436,6 @@ private: mutable std::unordered_map values_; void storeLocalValues() const; - // Evaluator interface -public: - void getNames(std::vector* names); - real getValue(const std::string& name, Error* err) const; - std::string getType(const std::string& name, Error* err) const; }; /* diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 6384d8c8ce13dae8b58ed1069d496dd8e93eaa8a..8ab748ed71e9a5dc0ee0259a78a2b886870bec5b 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -57,7 +57,7 @@ 
class AddOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, ops::AddOpGrad); +REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad); REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index ac76326262c88e2014cf64f7fb73b5a7338ab3e9..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -67,7 +67,8 @@ OnehotCrossEntropy Operator. namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker, ops::OnehotCrossEntropyGradientOp); + ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel); REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 07fa704824174f939e459093b245036771d9cd4f..123bed296c462c30bddd3bfbd530098fdbfe4856 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -63,7 +63,8 @@ Out = X[Index] } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, ops::GatherGradOp); +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, + ops::GatherGradOp); REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index a85363ad81d2a23e7267026c067f74f8c94c4786..056447901d418355c01d499f7d92d0b59a39edfa 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -19,13 +19,12 @@ template class CPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); + float mean = context.GetAttr("mean"); + float std = context.GetAttr("std"); auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); + unsigned int seed = static_cast(context.GetAttr("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 018a4bfcb26b9008c054000c91edf01e371fd82b..833a82bbf293a0892531283dc681ca2edd72f6a1 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); + unsigned int seed = static_cast(context.GetAttr("seed")); if (seed == 0) { std::random_device rd; seed = rd(); } - T mean = static_cast(context.op_.GetAttr("mean")); - T std = static_cast(context.op_.GetAttr("std")); + T mean = static_cast(context.GetAttr("mean")); + T std = static_cast(context.GetAttr("std")); thrust::counting_iterator index_sequence_begin(0); ssize_t N = framework::product(tensor->dims()); thrust::transform(index_sequence_begin, 
index_sequence_begin + N, diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index c3108ba8ec7ad85bd3485c135bf03e514bc66cd1..94d40890a765413e88a35a6ad995ca97ac84dcda 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -66,7 +66,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, - ops::LookupTableOpGrad); + lookup_table_grad, ops::LookupTableOpGrad); REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel); REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index e66e0abb25f9b933025a6d098ed9dd9eb18a47a5..d3d0e55a674587fb04f43f24d0790de4358f035a 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -54,7 +54,7 @@ class MeanGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradOp); +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); REGISTER_OP_CPU_KERNEL(mean_grad, diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index b4afebcd97a8efff70aaaa85bc2ec5455ddd05c5..1eee9644babbdfac68821ca774845ad8ebbd5aee 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -81,6 +81,7 @@ class MinusGradOp : public NetOp { USE_OP(scale); USE_OP_ITSELF(identity); namespace ops = paddle::operators; -REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradOp); +REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, + ops::MinusGradOp); REGISTER_OP_CPU_KERNEL(minus, ops::MinusKernel); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index f668008a10f74406b36e16318f756a82893e06af..d301a8619f966c1269c76e6d6085d4cf187a3fd2 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -33,11 +33,11 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, "The rank of input tensor X(%s) should be larger than " "`mul_op`'s `X_num_raw_dims`.", - ctx.op_.Input("X")); + ctx.op().Input("X")); PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, "The rank of input tensor Y(%s) should be larger than " "`mul_op`'s `Y_num_raw_dims`.", - ctx.op_.Input("Y")); + ctx.op().Input("Y")); PADDLE_ENFORCE_EQ( product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), product(y_dim, 0, y_dim.size() - y_num_row_dims), @@ -113,7 +113,7 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGrad); +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); REGISTER_OP_CPU_KERNEL(mul_grad, ops::MulGradKernel); diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 63de91254f4b75587cb2fb29aeb8ff7358ba8e76..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -74,7 +74,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, - ops::RowwiseAddGradOp); + rowwise_add_grad, ops::RowwiseAddGradOp); 
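// A hedged summary of the registration pattern repeated throughout this
// diff: REGISTER_OP now takes the gradient operator's type name as an
// explicit fourth argument,
//   REGISTER_OP(op_type, OpClass, OpMakerClass, grad_op_type, GradOpClass);
// instead of implicitly deriving op_type + "_grad". The names op_type and
// grad_op_type above are placeholders, not operators defined by this diff.
// An empty grad_op_type, as expanded by REGISTER_OP_WITHOUT_GRADIENT,
// skips gradient registration entirely.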
REGISTER_OP_CPU_KERNEL( rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 4e039688d4d74f2a101fc91c747bd1e6ebec7ad2..8e96a74c94ab7ff4d8c3266695e5157aff67905b 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -97,7 +97,7 @@ class IdentityOp : public NetOp { namespace ops = paddle::operators; -REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, +REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad, ops::ScaleGradOp); REGISTER_OP_CPU_KERNEL(scale, ops::ScaleKernel); diff --git a/paddle/operators/scale_op.h b/paddle/operators/scale_op.h index aea64f1b0428ffe79ba8d90cf79dbfd2b5ef36f4..65fb77eefad812fa52ac053b791ba1b8f480375f 100644 --- a/paddle/operators/scale_op.h +++ b/paddle/operators/scale_op.h @@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel { auto* in = context.Input("X"); tensor->mutable_data(in->place()); - auto scale = static_cast(context.op_.GetAttr("scale")); + auto scale = static_cast(context.GetAttr("scale")); auto eigen_out = framework::EigenVector::Flatten(*tensor); auto eigen_in = framework::EigenVector::Flatten(*in); diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 35c185ad80f93d1005c1616dcffd2e61bcd54222..f901edefa22dc9a252e87116df756d04767a7162 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -77,7 +77,8 @@ Out[Index] = Ref[Index] + Updates } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, ops::ScatterGradOp); +REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad, + ops::ScatterGradOp); REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index a0b5000ffbf54364e15f87870913926a071fa972..8422b622ee54ba76fb98b7dacfa9618031c1c88c 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel { auto param = ctx.Input("param"); auto grad = ctx.Input("grad"); auto param_out = ctx.Output("param_out"); - float lr = ctx.op_.GetAttr("learning_rate"); + float lr = ctx.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index f35b7023845bac52887d81a8f5c496cb5e7193aa..761c6de8d4d2150b30b97b58da95da3d5f33db63 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -53,7 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, ops::SigmoidOpGrad); +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::SigmoidOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 471bb288fb20f113aefb2a9e13eb805b161b0631..40c51a64c49bc064f55975ef6ced1d54070f1291 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -62,7 +62,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGrad); +REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad, + ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, 
ops::SoftmaxKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 29491137e6d8b4bfa2d0d07d48ffed1212a6131f..2d943c4508490d25a8747330c92c24c384bd0232 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); + unsigned int seed = static_cast(context.GetAttr("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } engine.seed(seed); std::uniform_real_distribution dist( - static_cast(context.op_.GetAttr("min")), - static_cast(context.op_.GetAttr("max"))); + static_cast(context.GetAttr("min")), + static_cast(context.GetAttr("max"))); ssize_t size = framework::product(tensor->dims()); for (ssize_t i = 0; i < size; ++i) { data[i] = dist(engine); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 1d6709934cbbcf50265eabef87c857654f783ed8..df993c07794b0b2408e4edc8a45fae9a17aef01c 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); + unsigned int seed = static_cast(context.GetAttr("seed")); if (seed == 0) { std::random_device rd; seed = rd(); } - T min = static_cast(context.op_.GetAttr("min")); - T max = static_cast(context.op_.GetAttr("max")); + T min = static_cast(context.GetAttr("min")); + T max = static_cast(context.GetAttr("max")); thrust::counting_iterator index_sequence_begin(0); ssize_t N = framework::product(tensor->dims()); thrust::transform(index_sequence_begin, index_sequence_begin + N, diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 120eb1e4af9cef43e76e27d4ad66acfbbd597a36..17bdac8749e31565b119b2cb84aed199fac0f441 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -22,3 +22,5 @@ ENDIF() cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) + +nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..24ddf3441caa6e5f45a7b96af26a23ed324dc1b6 --- /dev/null +++ b/paddle/platform/cudnn_helper.h @@ -0,0 +1,200 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/macros.h" + +namespace paddle { +namespace platform { + +enum class DataLayout { + kNHWC, + kNCHW, + kNCHW_VECT_C, +}; + +enum class PoolingMode { + kMaximum, + kAverage, +}; + +template +class CudnnDataType; + +template <> +class CudnnDataType { + public: + static const cudnnDataType_t type = CUDNN_DATA_FLOAT; +}; + +template <> +class CudnnDataType { + public: + static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; +}; + +inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { + switch (order) { + case DataLayout::kNHWC: + return CUDNN_TENSOR_NHWC; + case DataLayout::kNCHW: + return CUDNN_TENSOR_NCHW; + default: + PADDLE_THROW("Unknown cudnn equivalent for order"); + } + return CUDNN_TENSOR_NCHW; +} + +class ScopedTensorDescriptor { + public: + ScopedTensorDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateTensorDescriptor(&desc_)); + } + ~ScopedTensorDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyTensorDescriptor(desc_)); + } + + inline cudnnTensorDescriptor_t descriptor(const cudnnTensorFormat_t format, + const cudnnDataType_t type, + const std::vector& dims) { + // the format is not used now, but it maybe useful feature + std::vector strides(dims.size()); + strides[dims.size() - 1] = 1; + for (int i = dims.size() - 2; i >= 0; i--) { + strides[i] = dims[i + 1] * strides[i + 1]; + } + PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor( + desc_, type, dims.size(), dims.data(), strides.data())); + return desc_; + } + + template + inline cudnnTensorDescriptor_t descriptor(const DataLayout& order, + const std::vector& dims) { + return descriptor(GetCudnnTensorFormat(order), CudnnDataType::type, + dims); + } + + private: + cudnnTensorDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor); +}; + +class ScopedFilterDescriptor { + public: + ScopedFilterDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateFilterDescriptor(&desc_)); + } + ~ScopedFilterDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyFilterDescriptor(desc_)); + } + + inline cudnnFilterDescriptor_t descriptor(const cudnnTensorFormat_t format, + const cudnnDataType_t type, + const std::vector& kernel) { + // filter layout: output input spatial_dim_y spatial_dim_x + PADDLE_ENFORCE(dynload::cudnnSetFilterNdDescriptor( + desc_, type, format, kernel.size(), kernel.data())); + return desc_; + } + + template + inline cudnnFilterDescriptor_t descriptor(const DataLayout& order, + const std::vector& kernel) { + return descriptor(GetCudnnTensorFormat(order), CudnnDataType::type, + kernel); + } + + private: + cudnnFilterDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor); +}; + +class ScopedConvolutionDescriptor { + public: + ScopedConvolutionDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreateConvolutionDescriptor(&desc_)); + } + ~ScopedConvolutionDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyConvolutionDescriptor(desc_)); + } + + inline cudnnConvolutionDescriptor_t descriptor( + cudnnDataType_t type, const std::vector& pads, + const std::vector& strides, const std::vector& dilations) { + PADDLE_ENFORCE_EQ(pads.size(), strides.size()); + PADDLE_ENFORCE_EQ(pads.size(), dilations.size()); + +#if CUDNN_VERSION < 6000 + // cudnn v5 does not support dilation conv, the argument is called upscale + // instead of dilations and it is must be one. 
+ for (size_t i = 0; i < dilations.size(); ++i) { + PADDLE_ENFORCE_EQ( + dilations[i], 1, + "Dilations conv is not supported in this cuDNN version"); + } +#endif + + PADDLE_ENFORCE(dynload::cudnnSetConvolutionNdDescriptor( + desc_, pads.size(), pads.data(), strides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, type)); + return desc_; + } + + template + inline cudnnConvolutionDescriptor_t descriptor( + const std::vector& pads, const std::vector& strides, + const std::vector& dilations) { + return descriptor(CudnnDataType::type, pads, strides, dilations); + } + + private: + cudnnConvolutionDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedConvolutionDescriptor); +}; + +class ScopedPoolingDescriptor { + public: + ScopedPoolingDescriptor() { + PADDLE_ENFORCE(dynload::cudnnCreatePoolingDescriptor(&desc_)); + } + ~ScopedPoolingDescriptor() { + PADDLE_ENFORCE(dynload::cudnnDestroyPoolingDescriptor(desc_)); + } + + inline cudnnPoolingDescriptor_t descriptor(const PoolingMode& mode, + const std::vector& kernel, + const std::vector& pads, + const std::vector& strides) { + PADDLE_ENFORCE_EQ(kernel.size(), pads.size()); + PADDLE_ENFORCE_EQ(kernel.size(), strides.size()); + PADDLE_ENFORCE(dynload::cudnnSetPoolingNdDescriptor( + desc_, (mode == PoolingMode::kMaximum + ? CUDNN_POOLING_MAX + : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), + CUDNN_PROPAGATE_NAN, // Always propagate nans. + kernel.size(), kernel.data(), pads.data(), strides.data())); + return desc_; + } + + private: + cudnnPoolingDescriptor_t desc_; + DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor); +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/cudnn_helper_test.cc b/paddle/platform/cudnn_helper_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bd85ae1ca8b47b203e0321e9d9224d5cfd3a586 --- /dev/null +++ b/paddle/platform/cudnn_helper_test.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/platform/cudnn_helper.h" +#include + +TEST(CudnnHelper, ScopedTensorDescriptor) { + using paddle::platform::ScopedTensorDescriptor; + using paddle::platform::DataLayout; + + ScopedTensorDescriptor tensor_desc; + std::vector shape = {2, 4, 6, 6}; + auto desc = tensor_desc.descriptor(DataLayout::kNCHW, shape); + + cudnnDataType_t type; + int nd; + std::vector dims(4); + std::vector strides(4); + paddle::platform::dynload::cudnnGetTensorNdDescriptor( + desc, 4, &type, &nd, dims.data(), strides.data()); + + EXPECT_EQ(nd, 4); + for (size_t i = 0; i < dims.size(); ++i) { + EXPECT_EQ(dims[i], shape[i]); + } + EXPECT_EQ(strides[3], 1); + EXPECT_EQ(strides[2], 6); + EXPECT_EQ(strides[1], 36); + EXPECT_EQ(strides[0], 144); +} + +TEST(CudnnHelper, ScopedFilterDescriptor) { + using paddle::platform::ScopedFilterDescriptor; + using paddle::platform::DataLayout; + + ScopedFilterDescriptor filter_desc; + std::vector shape = {2, 3, 3}; + auto desc = filter_desc.descriptor(DataLayout::kNCHW, shape); + + cudnnDataType_t type; + int nd; + cudnnTensorFormat_t format; + std::vector kernel(3); + paddle::platform::dynload::cudnnGetFilterNdDescriptor(desc, 3, &type, &format, + &nd, kernel.data()); + + EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format); + EXPECT_EQ(nd, 3); + for (size_t i = 0; i < shape.size(); ++i) { + EXPECT_EQ(kernel[i], shape[i]); + } +} + +TEST(CudnnHelper, ScopedConvolutionDescriptor) { + using paddle::platform::ScopedConvolutionDescriptor; + + ScopedConvolutionDescriptor conv_desc; + std::vector src_pads = {2, 2, 2}; + std::vector src_strides = {1, 1, 1}; + std::vector src_dilations = {1, 1, 1}; + auto desc = conv_desc.descriptor(src_pads, src_strides, src_dilations); + + cudnnDataType_t type; + cudnnConvolutionMode_t mode; + int nd; + std::vector pads(3); + std::vector strides(3); + std::vector dilations(3); + paddle::platform::dynload::cudnnGetConvolutionNdDescriptor( + desc, 3, &nd, pads.data(), strides.data(), dilations.data(), &mode, + &type); + + EXPECT_EQ(nd, 3); + for (size_t i = 0; i < src_pads.size(); ++i) { + EXPECT_EQ(pads[i], src_pads[i]); + EXPECT_EQ(strides[i], src_strides[i]); + EXPECT_EQ(dilations[i], src_dilations[i]); + } + EXPECT_EQ(mode, CUDNN_CROSS_CORRELATION); +} + +TEST(CudnnHelper, ScopedPoolingDescriptor) { + using paddle::platform::ScopedPoolingDescriptor; + using paddle::platform::PoolingMode; + + ScopedPoolingDescriptor pool_desc; + std::vector src_kernel = {2, 2, 5}; + std::vector src_pads = {1, 1, 2}; + std::vector src_strides = {2, 2, 3}; + auto desc = pool_desc.descriptor(PoolingMode::kMaximum, src_kernel, src_pads, + src_strides); + + cudnnPoolingMode_t mode; + cudnnNanPropagation_t nan_t = CUDNN_PROPAGATE_NAN; + int nd; + std::vector kernel(3); + std::vector pads(3); + std::vector strides(3); + paddle::platform::dynload::cudnnGetPoolingNdDescriptor( + desc, 3, &mode, &nan_t, &nd, kernel.data(), pads.data(), strides.data()); + + EXPECT_EQ(nd, 3); + for (size_t i = 0; i < src_pads.size(); ++i) { + EXPECT_EQ(kernel[i], src_kernel[i]); + EXPECT_EQ(pads[i], src_pads[i]); + EXPECT_EQ(strides[i], src_strides[i]); + } + EXPECT_EQ(mode, CUDNN_POOLING_MAX); +} diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index d205ead84598e04eea523be32139959a02e0dd83..ceb66f84b6b01892cbaf61c79a47ae60d2589164 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) 
-nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc)
+nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc DEPS dynamic_loader)
diff --git a/paddle/platform/dynload/cudnn.h b/paddle/platform/dynload/cudnn.h
index ef0dd85b083dc2335dd5c70d3dc5f59eda25daeb..0120625b7c14448f1b8deb88c24a3ee06eaf4f01 100644
--- a/paddle/platform/dynload/cudnn.h
+++ b/paddle/platform/dynload/cudnn.h
@@ -62,19 +62,27 @@ extern void* cudnn_dso_handle;
 #define CUDNN_DNN_ROUTINE_EACH(__macro)                \
   __macro(cudnnSetTensor4dDescriptor);                 \
   __macro(cudnnSetTensor4dDescriptorEx);               \
+  __macro(cudnnSetTensorNdDescriptor);                 \
+  __macro(cudnnGetTensorNdDescriptor);                 \
   __macro(cudnnGetConvolutionNdForwardOutputDim);      \
   __macro(cudnnGetConvolutionForwardAlgorithm);        \
   __macro(cudnnCreateTensorDescriptor);                \
   __macro(cudnnDestroyTensorDescriptor);               \
   __macro(cudnnCreateFilterDescriptor);                \
   __macro(cudnnSetFilter4dDescriptor);                 \
+  __macro(cudnnSetFilterNdDescriptor);                 \
+  __macro(cudnnGetFilterNdDescriptor);                 \
   __macro(cudnnSetPooling2dDescriptor);                \
+  __macro(cudnnSetPoolingNdDescriptor);                \
+  __macro(cudnnGetPoolingNdDescriptor);                \
   __macro(cudnnDestroyFilterDescriptor);               \
   __macro(cudnnCreateConvolutionDescriptor);           \
   __macro(cudnnCreatePoolingDescriptor);               \
   __macro(cudnnDestroyPoolingDescriptor);              \
   __macro(cudnnSetConvolution2dDescriptor);            \
   __macro(cudnnDestroyConvolutionDescriptor);          \
+  __macro(cudnnSetConvolutionNdDescriptor);            \
+  __macro(cudnnGetConvolutionNdDescriptor);            \
   __macro(cudnnCreate);                                \
   __macro(cudnnDestroy);                               \
   __macro(cudnnSetStream);                             \
diff --git a/paddle/platform/macros.h b/paddle/platform/macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a04a38c0c6f905639004dea2f4416ecc57c8620
--- /dev/null
+++ b/paddle/platform/macros.h
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+// Disable the copy and assignment operator for a class.
+#ifndef DISABLE_COPY_AND_ASSIGN
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+ private:                                  \
+  classname(const classname&) = delete;   \
+  classname& operator=(const classname&) = delete
+#endif
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 2bd274fad2ab7eed0902ffe944c6e0670f963233..47ac601e678013aceb62005d6f25595f49673d2c 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -53,7 +53,7 @@ __all__ = [
     'cos_sim',
     'hsigmoid',
     'conv_projection',
-    'mse_cost',
+    'square_error_cost',
     'regression_cost',
     'classification_cost',
     'LayerOutput',
@@ -4238,13 +4238,18 @@ def __cost_input__(input, label, weight=None):
 
 @wrap_name_default()
 @layer_support()
-def mse_cost(input, label, weight=None, name=None, coeff=1.0, layer_attr=None):
+def square_error_cost(input,
+                      label,
+                      weight=None,
+                      name=None,
+                      coeff=1.0,
+                      layer_attr=None):
     """
-    mean squared error cost:
+    sum of squared errors cost:
 
     .. math::
 
-        \\frac{1}{N}\sum_{i=1}^N(t_i-y_i)^2
+        cost = \\sum_{i=1}^N(t_i-y_i)^2
 
     :param name: layer name.
     :type name: basestring
@@ -4273,7 +4278,7 @@ def mse_cost(input, label, weight=None, name=None, coeff=1.0, layer_attr=None):
     return LayerOutput(name, LayerType.COST, parents=parents, size=1)
 
 
-regression_cost = mse_cost
+regression_cost = square_error_cost
 
 
 @wrap_name_default("cost")
@@ -5798,9 +5803,9 @@ def huber_regression_cost(input,
                           coeff=1.0,
                           layer_attr=None):
     """
-    In statistics, the Huber loss is a loss function used in robust regression,
-    that is less sensitive to outliers in data than the squared error loss.
-    Given a prediction f(x), a label y and :math:`\delta`, the loss function
+    In statistics, the Huber loss is a loss function used in robust regression
+    that is less sensitive to outliers in data than the squared error loss.
+    Given a prediction f(x), a label y and :math:`\delta`, the loss function
     is defined as:
 
     .. math:
@@ -5848,13 +5853,13 @@ def huber_classification_cost(input,
                               coeff=1.0,
                               layer_attr=None):
     """
-    For classification purposes, a variant of the Huber loss called modified Huber
-    is sometimes used. Given a prediction f(x) (a real-valued classifier score) and
-    a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber
+    For classification purposes, a variant of the Huber loss called modified Huber
+    is sometimes used. Given a prediction f(x) (a real-valued classifier score) and
+    a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber
     loss is defined as:
 
     .. math:
-        loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1
+        loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1
         loss = -4yf(x), \text{otherwise}
 
     The example usage is:
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers_with_weight.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers_with_weight.protostr
index 96fb1d4ebde08b1bca2ffd09e8db0895842cbfd3..cec8a73db66f6091ec971527b3a42aa9e08154eb 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers_with_weight.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers_with_weight.protostr
@@ -45,7 +45,7 @@ layers {
   coeff: 1.0
 }
 layers {
-  name: "__mse_cost_0__"
+  name: "__square_error_cost_0__"
   type: "square_error"
   size: 1
   active_type: ""
@@ -130,7 +130,7 @@ input_layer_names: "label"
 input_layer_names: "weight"
 input_layer_names: "multi_class_label"
 output_layer_names: "__cost_0__"
-output_layer_names: "__mse_cost_0__"
+output_layer_names: "__square_error_cost_0__"
 output_layer_names: "__nce_layer_0__"
 evaluators {
   name: "classification_error_evaluator"
@@ -146,7 +146,7 @@ sub_models {
   layer_names: "weight"
   layer_names: "__fc_layer_0__"
   layer_names: "__cost_0__"
-  layer_names: "__mse_cost_0__"
+  layer_names: "__square_error_cost_0__"
   layer_names: "multi_class_label"
   layer_names: "__nce_layer_0__"
   input_layer_names: "input"
@@ -154,7 +154,7 @@ sub_models {
   input_layer_names: "weight"
   input_layer_names: "multi_class_label"
   output_layer_names: "__cost_0__"
-  output_layer_names: "__mse_cost_0__"
+  output_layer_names: "__square_error_cost_0__"
   output_layer_names: "__nce_layer_0__"
   evaluator_names: "classification_error_evaluator"
   is_recurrent_layer_group: false
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py
index c369062930e2b067ceab0dc3b25ba6c1eabe2450..caa6aaa9430ffaee7ade93ee04ec90103bf8cf43 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers_with_weight.py
@@ -10,7 +10,7 @@ fc = fc_layer(input=data, size=10, act=SoftmaxActivation())
 outputs(
     classification_cost(
         input=fc, label=lbl, weight=wt),
-    mse_cost(
+    square_error_cost(
         input=fc, label=lbl, weight=wt),
     nce_layer(
         input=fc,
diff --git a/python/paddle/v2/tests/test_layer.py b/python/paddle/v2/tests/test_layer.py
index 783a0ca85dc61b9f00ac8126e03788884dfb44cb..de932ad715bea8db158393c3c192ef67502e2fa3 100644
--- a/python/paddle/v2/tests/test_layer.py
+++ b/python/paddle/v2/tests/test_layer.py
@@ -134,8 +134,9 @@ class CostLayerTest(unittest.TestCase):
         cost3 = layer.cross_entropy_cost(input=inference, label=label)
         cost4 = layer.cross_entropy_with_selfnorm_cost(
             input=inference, label=label)
-        cost5 = layer.mse_cost(input=inference, label=label)
-        cost6 = layer.mse_cost(input=inference, label=label, weight=weight)
+        cost5 = layer.square_error_cost(input=inference, label=label)
+        cost6 = layer.square_error_cost(
+            input=inference, label=label, weight=weight)
         cost7 = layer.multi_binary_label_cross_entropy_cost(
             input=inference, label=label)
         cost8 = layer.rank_cost(left=score, right=score, label=score)
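Note on the new platform helpers: `ScopedTensorDescriptor`, `ScopedFilterDescriptor`, `ScopedConvolutionDescriptor`, and `ScopedPoolingDescriptor` all follow the same RAII pattern -- the constructor calls the matching `cudnnCreate*Descriptor`, the destructor calls `cudnnDestroy*Descriptor`, and `DISABLE_COPY_AND_ASSIGN` (from the new `paddle/platform/macros.h`) rules out copies that could double-free a handle. Below is a minimal usage sketch, not part of the patch itself: the function name and the concrete shapes are hypothetical, and only APIs defined in this patch are assumed.

```cpp
#include <vector>

#include "paddle/platform/cudnn_helper.h"

// Hypothetical helper, for illustration only: builds the cuDNN descriptors
// for a 2x2 max pooling over an NCHW float input of shape {2, 4, 6, 6}.
void PoolingDescriptorsSketch() {
  using paddle::platform::DataLayout;
  using paddle::platform::PoolingMode;
  using paddle::platform::ScopedPoolingDescriptor;
  using paddle::platform::ScopedTensorDescriptor;

  // Describe the input tensor; descriptor<float>() selects CUDNN_DATA_FLOAT
  // via CudnnDataType<T> and derives dims/strides from the shape vector.
  ScopedTensorDescriptor input_desc;
  std::vector<int> shape = {2, 4, 6, 6};
  cudnnTensorDescriptor_t in =
      input_desc.descriptor<float>(DataLayout::kNCHW, shape);

  // Describe 2x2 max pooling with zero padding and stride 2.
  ScopedPoolingDescriptor pool_desc;
  cudnnPoolingDescriptor_t pool = pool_desc.descriptor(
      PoolingMode::kMaximum, /*kernel=*/{2, 2}, /*pads=*/{0, 0},
      /*strides=*/{2, 2});

  // `in` and `pool` would now be handed to a cuDNN call such as
  // cudnnPoolingForward. Nothing needs to be destroyed by hand: both
  // descriptors are released when input_desc and pool_desc leave scope,
  // which is exactly what the Scoped wrappers exist for.
  (void)in;
  (void)pool;
}
```

The same composition applies to `ScopedFilterDescriptor` and `ScopedConvolutionDescriptor` for convolution operators, as exercised by the four tests in `cudnn_helper_test.cc` above.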