diff --git a/doc/howto/cross_compiling/cross_compiling_for_ios_cn.md b/doc/howto/cross_compiling/cross_compiling_for_ios_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..32c490d9aa4202e17aa1784a45a317c5307b98ea
--- /dev/null
+++ b/doc/howto/cross_compiling/cross_compiling_for_ios_cn.md
@@ -0,0 +1,99 @@
+# 构建iOS平台上的PaddlePaddle库
+交叉编译iOS平台上适用的PaddlePaddle库,需要在MacOS系统上进行。本文将介绍在MacOS上,从源码交叉编译iOS平台上适用的PaddlePaddle库。
+
+## 准备交叉编译环境
+Apple官方为iOS开发提供了完整的交叉编译工具和集成开发环境,用户从App Store下载安装Xcode即可。也可自行前往[Xcode官网](https://developer.apple.com/cn/xcode/)下载。安装完成之后,可在命令行执行`xcodebuild -version`,判断是否安装成功。
+
+```bash
+$ xcodebuild -version
+Xcode 9.0
+Build version 9A235
+```
+
+## 配置交叉编译参数
+
+PaddlePaddle为交叉编译提供了工具链配置文件[cmake/cross_compiling/ios.cmake](https://github.com/PaddlePaddle/Paddle/blob/develop/cmake/cross_compiling/ios.cmake),以提供一些默认的编译器和编译参数配置。
+
+交叉编译iOS版本的PaddlePaddle库时,有一些必须配置的参数:
+
+- `CMAKE_SYSTEM_NAME`,CMake编译的目标平台,必须设置为`iOS`。在设置`CMAKE_SYSTEM_NAME=iOS`后,PaddlePaddle的CMake系统会自动编译所有的第三方依赖库,并且强制设置一些PaddlePaddle参数的值(`WITH_C_API=ON`、`WITH_GPU=OFF`、`WITH_AVX=OFF`、`WITH_PYTHON=OFF`、`WITH_RDMA=OFF`)。
+- `WITH_C_API`,是否编译C-API预测库,必须设置为`ON`。在iOS平台上只支持使用C-API来预测。
+- `WITH_SWIG_PY`,必须设置为`OFF`。在iOS平台上不支持通过swig调用来训练或者预测。
+
+iOS平台可选配置参数:
+
+- `IOS_PLATFORM`,可设置为`OS/SIMULATOR`,默认值为`OS`。
+  - `OS`,构建目标为`arm`架构的iPhone或者iPad等物理设备。
+  - `SIMULATOR`,构建目标为`x86`架构的模拟器平台。
+- `IOS_ARCH`,目标架构。针对不同的`IOS_PLATFORM`,可设置的目标架构如下表所示:
+
+  | IOS_PLATFORM | IOS_ARCH                    |
+  |--------------|-----------------------------|
+  | OS           | armv7, armv7s, arm64 (默认) |
+  | SIMULATOR    | i386, x86_64 (默认)         |
+
+- `IOS_DEPLOYMENT_TARGET`,最小的iOS部署版本,默认值为`7.0`。
+- `IOS_ENABLE_BITCODE`,是否使能[Bitcode](https://developer.apple.com/library/content/documentation/IDEs/Conceptual/AppDistributionGuide/AppThinning/AppThinning.html#//apple_ref/doc/uid/TP40012582-CH35-SW3),可设置`ON/OFF`,默认值为`ON`。
+- `IOS_USE_VECLIB_FOR_BLAS`,是否使用[vecLib](https://developer.apple.com/documentation/accelerate/veclib)框架进行BLAS矩阵计算,可设置`ON/OFF`,默认值为`OFF`。
+- `IOS_DEVELOPMENT_ROOT`,`Developer`目录,可显式指定为`/path/to/platform/Developer`。若未显式指定,PaddlePaddle将会根据`IOS_PLATFORM`自动选择`Xcode`对应`platform`的`Developer`目录。
+- `IOS_SDK_ROOT`,所使用`SDK`的根目录,可显式指定为`/path/to/platform/Developer/SDKs/SDK`。若未显式指定,PaddlePaddle将会自动选择`IOS_DEVELOPMENT_ROOT`目录下最新的`SDK`版本。
+
+其他配置参数:
+
+- `USE_EIGEN_FOR_BLAS`,是否使用Eigen库进行矩阵计算,在`IOS_USE_VECLIB_FOR_BLAS=OFF`时有效。可设置`ON/OFF`,默认值为`OFF`。
+- `HOST_C/CXX_COMPILER`,宿主机的C/C++编译器。默认值为环境变量`CC/CXX`的值;若环境变量`CC/CXX`未设置,则使用`cc/c++`编译器。
+
+常用的cmake配置如下。构建`arm64`架构、适用于真机的库:
+
+```bash
+cmake -DCMAKE_SYSTEM_NAME=iOS \
+      -DIOS_PLATFORM=OS \
+      -DIOS_ARCH="arm64" \
+      -DIOS_ENABLE_BITCODE=ON \
+      -DIOS_USE_VECLIB_FOR_BLAS=ON \
+      -DCMAKE_INSTALL_PREFIX=your/path/to/install \
+      -DWITH_C_API=ON \
+      -DWITH_TESTING=OFF \
+      -DWITH_SWIG_PY=OFF \
+      ..
+```
+
+构建`x86_64`架构、适用于模拟器的库:
+
+```bash
+cmake -DCMAKE_SYSTEM_NAME=iOS \
+      -DIOS_PLATFORM=SIMULATOR \
+      -DIOS_ARCH="x86_64" \
+      -DIOS_USE_VECLIB_FOR_BLAS=ON \
+      -DCMAKE_INSTALL_PREFIX=your/path/to/install \
+      -DWITH_C_API=ON \
+      -DWITH_TESTING=OFF \
+      -DWITH_SWIG_PY=OFF \
+      ..
+``` + +用户还可根据自己的需求设置其他编译参数。比如希望最小化生成库的大小,可以设置`CMAKE_BUILD_TYPE`为`MinSizeRel`;若希望得到最快的执行速度,则可设置`CMAKE_BUILD_TYPE`为`Release`。亦可以通过手动设置`CMAKE_C/CXX_FLAGS`来影响PaddlePaddle的编译过程。 + +**性能TIPS**,为了达到最快的计算速度,在CMake参数配置上,有以下建议: + +- 设置`CMAKE_BUILD_TYPE`为`Release` +- 设置`IOS_USE_VECLIB_FOR_BLAS=ON`,调用`vecLib`框架提供的BLAS函数进行矩阵计算。 + +## 编译和安装 + +CMake配置完成后,执行以下命令,PaddlePaddle将自动下载和编译所有第三方依赖库、编译和安装PaddlePaddle预测库。 + +``` +$ make +$ make install +``` + +注意:如果你曾在源码目录下编译过其他平台的PaddlePaddle库,请先使用`rm -rf`命令删除`third_party`目录和`build`目录,以确保所有的第三方依赖库和PaddlePaddle代码都是针对新的CMake配置重新编译的。 + +执行完安装命令后,`your/path/to/install`目录中会包含以下内容: + +- `include`目录,其中包含所有C-API的头文件 +- `lib`目录,其中包含PaddlePaddle的C-API静态库 +- `third_party`目录,其中包含所依赖的所有第三方库 + +注意,不同架构的PaddlePaddle库建议安装到不同的目录下,然后使用`lipo`工具将多个静态库合并成一个支持多个架构的fat库。 + +自此,PaddlePaddle库已经安装完成,用户可将合成的fat库用于深度学习相关的iOS App中,调用方法见C-API文档。 diff --git a/doc/howto/cross_compiling/cross_compiling_for_raspberry_cn.md b/doc/howto/cross_compiling/cross_compiling_for_raspberry_cn.md index 026c0c6f3b2a2ca322d063f38e1736a010e1197e..6e983645faaed1f67edaeeb82ddbef9cef6bb85f 100644 --- a/doc/howto/cross_compiling/cross_compiling_for_raspberry_cn.md +++ b/doc/howto/cross_compiling/cross_compiling_for_raspberry_cn.md @@ -59,4 +59,4 @@ make install 注意:如果你曾经在源码目录下编译过其他平台的PaddlePaddle库,请先使用`rm -rf`命令删除`third_party`目录和`build`目录,以确保所有的第三方依赖库和PaddlePaddle代码都是针对新的CMake配置重新编译的。 -执行完安装命令后,,`your/path/to/install`目录中会包含`include`和`lib`目录,其中`include`中包含C-API的头文件,`lib`中包含一个Raspberry Pi版本的库。 +执行完安装命令后,`your/path/to/install`目录中会包含`include`和`lib`目录,其中`include`中包含C-API的头文件,`lib`中包含一个Raspberry Pi版本的库。 diff --git a/doc/howto/cross_compiling/cross_compiling_for_raspberry_en.md b/doc/howto/cross_compiling/cross_compiling_for_raspberry_en.md index 09ac4733ec98c598dfd62f22beaf838320dc7531..3c1a5950ff9553bb725d5a96e3fdf2e5e9f6f95c 100644 --- a/doc/howto/cross_compiling/cross_compiling_for_raspberry_en.md +++ b/doc/howto/cross_compiling/cross_compiling_for_raspberry_en.md @@ -44,7 +44,7 @@ cmake -DCMAKE_SYSTEM_NAME=RPi \ .. ``` -To build the inference library, please set the argument WITH_API to ON: `WITH_C_API=ON`. +To build the inference library, please set the argument WITH\_C\_API to ON: `WITH_C_API=ON`. You can add more arguments. For example, to minimize the size of the generated inference library, you may use `CMAKE_BUILD_TYPE=MinSizeRel`. For performance optimization, you may use `CMAKE_BUILD_TYPE=Release`. 
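The iOS guide above recommends installing each architecture's build into its own prefix and then merging the per-architecture static libraries into one fat library with `lipo`, but does not show the command. A minimal sketch follows; the two install prefixes and the library name `libpaddle_capi_whole.a` are illustrative placeholders, so substitute whatever `make install` actually placed in your `lib` directories.

```bash
# Merge the arm64 (device) build and the x86_64 (simulator) build of the
# C-API static library into a single fat library supporting both architectures.
lipo -create \
    install_arm64/lib/libpaddle_capi_whole.a \
    install_x86_64/lib/libpaddle_capi_whole.a \
    -output libpaddle_capi_whole.a

# List the architectures contained in the merged library to verify the result.
lipo -info libpaddle_capi_whole.a
```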
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index c79c4d0c721f9e568c937cb9e524e925fcdc83d0..5b90fbfca7f6bec4f2c862d0ff18dfd7cf39e181 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) { lod_tensor.mutable_data(place); lod_tensor.set_lod(src_lod); - CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL); - CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL); + EXPECT_EQ(lod_tensor.lod_element(0, 2).first, 4UL); + EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL); auto lod = lod_tensor.lod(); @@ -45,6 +45,6 @@ TEST(LoDTensor, LoDInGPU) { cudaDeviceSynchronize(); for (size_t i = 0; i < src_lod[0].size(); ++i) { - CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); + EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); } -} \ No newline at end of file +} diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index aa46829fdde82b58a649108bf708901299cd8153..3be26fdc4fb6ebdd0ec427a2248b0f97d9edff01 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -37,32 +37,32 @@ ExecutionContext::GetEigenDevice() const { std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, - "Op %s input %s should contain only one variable", type_, - name); + "Operator %s's input %s should contain only one variable.", + type_, name); return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( const std::string& name) const { auto it = inputs_.find(name); - PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_, - name); + PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.", + type_, name); return it->second; } std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); PADDLE_ENFORCE_LE(outs.size(), 1UL, - "Op %s output %s should contain only one variable", type_, - name); + "Operator %s's output %s should contain only one variable.", + type_, name); return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( const std::string& name) const { auto it = outputs_.find(name); - PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s", - type_, name); + PADDLE_ENFORCE(it != outputs_.end(), + "Operator %s does not have an output called %s.", type_, name); return it->second; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 93885fa3028e072bc0bd021ea9287087678f3621..b8a7040ed024fc7b19980beef3d8b367dfdd7f50 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -427,7 +427,8 @@ class OperatorWithKernel : public OperatorBase { int tmp = static_cast(ToDataType(t->type())); VLOG(3) << "Input " << ipt_name << " with data_type " << tmp; PADDLE_ENFORCE(tmp == data_type || data_type == -1, - "DataType of Paddle Op %s must be same.", Type()); + "DataType of Paddle Op %s must be the same.", + Type()); data_type = tmp; } } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 7b9a5b75e1087a1cc3b6c6c7a6e4dc185c32dd42..9eab67561a42b3fb4e22d8475ad5eeb146a72f1c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -118,10 +118,12 @@ class Tensor { const platform::DeviceContext& ctx); /** - * @brief Return the slice of the tensor. + * @brief Return a sub-tensor of the given tensor. * - * @param[in] begin_idx The begin index of the slice. 
- * @param[in] end_idx The end index of the slice. + * @param[in] begin_idx The index of the start row(inclusive) to slice. + * The index number begins from 0. + * @param[in] end_idx The index of the end row(exclusive) to slice. + * The index number begins from 0. */ inline Tensor Slice(const int& begin_idx, const int& end_idx) const; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 29ac683f48fcde4dd3b5ad7f04b5d1d7434706ba..bcccdd5881775e199297dce7e70aaf6aae62d95a 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -112,9 +112,10 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } - PADDLE_ENFORCE_GT(numel(), 0, - "Tensor's numel must be larger than zero to call " - "Tensor::mutable_data. Call Tensor::set_dim first."); + PADDLE_ENFORCE_GT( + numel(), 0, + "When calling this method, the Tensor's numel must be larger than zero. " + "Please check Tensor::Resize has been called first."); int64_t size = numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || @@ -229,10 +230,12 @@ inline void Tensor::CopyFromVector(const std::vector& src, inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { check_memory_size(); - PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero."); - PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); - PADDLE_ENFORCE_LT(begin_idx, end_idx, - "Begin index must be less than end index."); + PADDLE_ENFORCE_GE(begin_idx, 0, + "The start row index must be greater than 0."); + PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound."); + PADDLE_ENFORCE_LT( + begin_idx, end_idx, + "The start row index must be lesser than the end row index."); if (dims_[0] == 1) { return *this; diff --git a/paddle/gserver/layers/CRFLayer.cpp b/paddle/gserver/layers/CRFLayer.cpp index 0b544420097e9150f8489731b6379dea633e992c..867303b4fa0d490297ab152fc2ad266e92e29baf 100644 --- a/paddle/gserver/layers/CRFLayer.cpp +++ b/paddle/gserver/layers/CRFLayer.cpp @@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) { : real(1.0f); instanceWeight *= coeff_; - MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]); - grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight); + if (output.grad) { + MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]); + grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight); + } if (needWGrad) { weight_->getWGrad()->add( *crfs_[i].getWGrad(), real(1.0f), instanceWeight); diff --git a/paddle/gserver/layers/LinearChainCRF.cpp b/paddle/gserver/layers/LinearChainCRF.cpp index dc3dc156792bdf32c3b948a292597d0e9eca5d8b..abaa1802b763a49f748214dbd4dec1d2bac53b59 100644 --- a/paddle/gserver/layers/LinearChainCRF.cpp +++ b/paddle/gserver/layers/LinearChainCRF.cpp @@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) { } void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) { - MatrixPtr matX = Matrix::create(x, length, numClasses_); Matrix::resizeOrCreate(matGrad_, length, numClasses_); Matrix::resizeOrCreate(beta_, length, numClasses_); real* b = b_->getData(); diff --git a/paddle/gserver/layers/SequenceReshapeLayer.cpp b/paddle/gserver/layers/SequenceReshapeLayer.cpp index 433592953b220eda4db4634124a57a2074cef4c0..822974407283c9ee6d0efee71bc945bc418b1942 100644 --- 
a/paddle/gserver/layers/SequenceReshapeLayer.cpp +++ b/paddle/gserver/layers/SequenceReshapeLayer.cpp @@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) { size_t outDim = getSize(); size_t numSequences = input.getNumSequences(); - auto startPositions = input.sequenceStartPositions->getVector(false); - const int* starts = startPositions->getData(); - CHECK_EQ(starts[numSequences], input.getBatchSize()); - CHECK_EQ(numSequences, startPositions->getSize() - 1); + // by default, we assume each instance as a sequence + IVectorPtr seqStarts; + IVector::resizeOrCreate(seqStarts, input.getBatchSize() + 1, false); + int* startsData = seqStarts->getData(); + for (int i = 0; i < input.getBatchSize() + 1; i++) { + startsData[i] = i; + } + const int* starts = startsData; + + // if there is sequence, then use start positions + if (input.sequenceStartPositions) { + auto startPositions = input.sequenceStartPositions->getVector(false); + starts = startPositions->getData(); + CHECK_EQ(starts[numSequences], input.getBatchSize()); + CHECK_EQ(numSequences, startPositions->getSize() - 1); + } for (size_t seqID = 0; seqID < numSequences; seqID++) { size_t inNumIns = starts[seqID + 1] - starts[seqID]; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 33166d9ce23a4a345fc00a65adf63281b13643c3..6b4e46f56a0c9c9836c5b353ec9c554454ab0491 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -41,7 +41,16 @@ void* CPUAllocator::Alloc(size_t& index, size_t size) { index = 0; // unlock memory - void* p = malloc(size); + void* p; + +#ifdef PADDLE_USE_MKLDNN + // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp + // memory alignment + PADDLE_ENFORCE_EQ(posix_memalign(&p, 4096ul, size), 0); +#else + PADDLE_ENFORCE_EQ(posix_memalign(&p, 32ul, size), 0); +#endif + PADDLE_ENFORCE(p, "Fail to allocate CPU memory: size = %d .", size); if (p != nullptr) { if (FLAGS_use_pinned_memory) { diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index d94b96200c2a5cd112b17e45aa6cd4a63bdd04d0..39df19da677a7dee7d0989d491f8d5511f73a9c7 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -28,8 +28,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], "The 1st dimension of Input(X) and Input(Label) should " "be equal."); @@ -38,8 +39,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { "If Attr(soft_label) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "If Attr(soft_label) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "If Attr(softLabel) == false, the 2nd dimension of " "Input(Label) should be 1."); } @@ -48,7 +49,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { } protected: - // CrossEntropy's data type just determined by "X" + // Explicitly set that data type of the output of the cross_entropy operator + // is determined by its input "X". 
framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { return framework::ToDataType(ctx.Input("X")->type()); diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..605dbba5af1bb8b0d718833be6af45fdaeac70ac --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.cc @@ -0,0 +1,261 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/linear_chain_crf_op.h" + +namespace paddle { +namespace operators { + +class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LinearChainCRFOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Emission", + "(LoDTensor, default: LoDTensor). " + "The unscaled emission weight matrix for the linear chain CRF. " + "This input is a LoDTensor with shape [N x D] where N is the size of " + "the mini-batch and D is the total tag number."); + AddInput( + "Transition", + "(Tensor, default: Tensor). A Tensor with shape [(D + 2) x D]. " + "The learnable parameter for the linear_chain_crf operator. " + "See more details in the operator's comments."); + AddInput( + "Label", + "(LoDTensor, default: LoDTensor). The ground truth which is a 2-D " + "LoDTensor with shape [N x 1], where N is the total element number in " + "a mini-batch."); + AddOutput( + "Alpha", + "Tensor, default: Tensor. The forward vectors for the entire " + "batch. A two dimensional tensor with shape [N x D], " + "denoted as \f$\alpha\f$. \f$\alpha$\f is a memo table used to " + "calculate the normalization factor in CRF. \f$\alpha[k, v]$\f stores " + "the unnormalized probabilites of all possible unfinished sequences of " + "tags that end at position \f$k$\f with tag \f$v$\f. For each \f$k$\f, " + "\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for " + "each tag value \f$v$\f. This vector is called a forward vecotr and " + "will also be used in backward computations.") + .AsIntermediate(); + AddOutput("EmissionExps", + "The exponentials of Input(Emission). This is an intermediate " + "computational result in forward computation, and will be reused " + "in backward computation.") + .AsIntermediate(); + AddOutput("TransitionExps", + "The exponentials of Input(Transition). This is an intermediate " + "computational result in forward computation, and will be reused " + "in backward computation.") + .AsIntermediate(); + AddOutput( + "LogLikelihood", + "(Tensor, default: Tensor). The logarithm of the conditional " + "likelihood of each training sample in a mini-batch. This is a 2-D " + "tensor with shape [S x 1], where S is the sequence number in a " + "mini-batch. Note: S is equal to the sequence number in a mini-batch. 
" + "The output is no longer a LoDTensor."); + AddComment(R"DOC( +Conditional Random Field defines an undirected probabilistic graph with nodes +denoting random variables and edges denoting dependencies between these +variables. CRF learns the conditional probability \f$P(Y|X)\f$, where +\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and +\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs. + +Linear chain CRF is a special case of CRF that is useful for sequence labeling +task. Sequence labeling tasks do not assume a lot of conditional +independences among inputs. The only constraint they impose is that the input +and output must be linear sequences. Thus, the graph of such a CRF is a simple +chain or a line, which results in the linear chain CRF. + +This operator implements the Forward-Backward algorithm for the linear chain +CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf and +http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for reference. + +Equation: + +- Denote Input(Emission) to this operator as \f$x\f$ here. +- The first D values of Input(Transition) to this operator are for starting +weights, denoted as \f$a\f$ here. +- The next D values of Input(Transition) of this operator are for ending +weights, denoted as \f$b\f$ here. +- The remaning values of Input(Transition) are for transition weights, +denoted as \f$w\f$ here. +- Denote Input(Label) as \f$s\f$ here. + +The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as: +\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L} + + \sum_{l=1}^L x_{s_l} + + \sum_{l=2}^L w_{s_{l-1},s_l})\f$ +where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over +all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight +to the linear chain CRF. + +Finaly, the linear chain CRF operator outputs the logarithm of the conditional +likelihood of each training sample in a mini-batch. + +NOTE: +1. The feature function for a CRF is made up of the emission features and the +transition features. The emission feature weights are NOT computed in +this operator. They MUST be computed first before this operator is called. + +2. Because this operator performs global normalization over all possible +sequences internally, it expects UNSCALED emission feature weights. +Please do not call this op with the emission feature being output of any +nonlinear activation. + +3. The 2nd dimension of Input(Emission) MUST be equal to the tag number. 
+ +)DOC"); + } +}; + +class LinearChainCRFOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Emission"), + "Input(Emission) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Transition"), + "Input(Transition) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Alpha"), + "Output(Alpha) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"), + "Output(EmissionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"), + "Output(TransitionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"), + "Output(LogLikelihood) should be not null."); + + auto emission_dims = ctx->GetInputDim("Emission"); + PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL, + "The Input(Emission) should be a 2-D tensor."); + PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed."); + + auto transition_dims = ctx->GetInputDim("Transition"); + PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL, + "The Input(Transition) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + transition_dims[0] - 2, transition_dims[1], + "An invalid dimension for the Input(Transition), which should " + "be a 2-D tensor with shape [(D + 2) x D]."); + PADDLE_ENFORCE_EQ( + emission_dims[1], transition_dims[1], + "The 2nd dimension of the Input(Emission) and the Input(Transition) " + "should be equal to the tag number."); + + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, + "The Input(Label) should be a 2-D tensor with the 2nd " + "dimensions fixed to 1."); + PADDLE_ENFORCE_EQ( + emission_dims[0], label_dims[0], + "The height of Input(Emission) and the height of Input(Label) " + "should be the same."); + + ctx->SetOutputDim("Alpha", emission_dims); + ctx->SetOutputDim("EmissionExps", emission_dims); + ctx->SetOutputDim("TransitionExps", transition_dims); + // TODO(caoying) This is tricky. The 1st dimension of Output(LogLikelihood) + // is the sequence number in a mini-batch. The dimension set here should be + // resized to its correct size in the function Compute. Fix this once we can + // get LoD information in the InferShape interface. + ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1}); + } + + protected: + // Explicitly set that the data type of output of the linear_chain_crf + // operator is determined by its input "Emission". 
+ framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("Emission")->type()); + } +}; + +class LinearChainCRFGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("EmissionExps"), + "Input(EmissionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("TransitionExps"), + "Input(TransitionExps) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")), + "Input(LogLikelihood@GRAD) shoudl be not null."); + + auto emission_exps_dims = ctx->GetInputDim("EmissionExps"); + PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL, + "The Input(EmissionExps) should be a 2-D tensor."); + PADDLE_ENFORCE(emission_exps_dims[0], + "An empty mini-batch is not allowed."); + + auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); + PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL, + "The Input(TransitionExps) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + transition_exps_dims[0] - 2, transition_exps_dims[1], + "An invalid dimension for the Input(TransitionExps), which should " + "be a 2-D tensor with shape [(D + 2) x D]."); + PADDLE_ENFORCE_EQ( + emission_exps_dims[1], transition_exps_dims[1], + "The 2nd dimension of the Input(EmissionExps) and the " + "Input(TransitionExps) should be equal to the tag number."); + + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, + "The Input(Label) should be a 2-D tensor with the 2nd " + "dimensions fixed to 1."); + PADDLE_ENFORCE_EQ( + emission_exps_dims[0], label_dims[0], + "The height of Input(EmissionExps) and the height of Input(Label) " + "should be the same."); + + if (ctx->HasOutput(framework::GradVarName("Emission"))) { + ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims); + } + if (ctx->HasOutput(framework::GradVarName("Transition"))) { + ctx->SetOutputDim(framework::GradVarName("Transition"), + transition_exps_dims); + } + } + + protected: + // Explicitly set that the data type of output of the linear_chain_crf_grad + // operator is determined by its input: gradients of LogLikelihood. + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input(framework::GradVarName("LogLikelihood"))->type()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker, + linear_chain_crf_grad, ops::LinearChainCRFGradOp); +REGISTER_OP_CPU_KERNEL( + linear_chain_crf, + ops::LinearChainCRFOpKernel, + ops::LinearChainCRFOpKernel); +REGISTER_OP_CPU_KERNEL( + linear_chain_crf_grad, + ops::LinearChainCRFGradOpKernel, + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.cu b/paddle/operators/linear_chain_crf_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fc8995f4c2ce05f89ffb58129695113f89159fa --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/linear_chain_crf_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + linear_chain_crf, + ops::LinearChainCRFOpKernel, + ops::LinearChainCRFOpKernel); +REGISTER_OP_GPU_KERNEL( + linear_chain_crf_grad, + ops::LinearChainCRFGradOpKernel, + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h new file mode 100644 index 0000000000000000000000000000000000000000..56fb0c9102bee6e2fefd1180ef20237891573f70 --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.h @@ -0,0 +1,543 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +static inline T NormalizeL1(T* x, size_t len) { + T sum = 0.; + for (size_t i = 0; i < len; ++i) sum += x[i]; + // (This comment is from the old LinearChainCRFLayer.) + // Right now, we just bet that sum won't be zero. If this really happens, we + // will figure out what should be done then. + PADDLE_ENFORCE(sum, + "The unnormalized probabilities of all possible unfinished " + "sequences must be greater than 0."); + T s = 1. / sum; + for (size_t i = 0; i < len; ++i) x[i] *= s; + return sum; +} + +template +struct ScalarMul { + explicit ScalarMul(const T& scalar) : scalar(scalar) {} + T operator()(const T& val) const { return val * scalar; } + + T scalar; +}; + +using framework::LoDTensor; +using framework::LoD; +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class LinearChainCRFOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // TODO(caoying) The checks related to LoD information should be + // moved into InferShape once after the InferShape is refactored. + PADDLE_ENFORCE_EQ(ctx.Input("Emission")->NumLevels(), 1UL, + "The Input(Emission) should be a sequence."); + PADDLE_ENFORCE_EQ(ctx.Input("Label")->NumLevels(), 1UL, + "The Input(Label) should be a sequence."); + auto in_lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(in_lod.size(), "Input(Label) must be a sequence."); + const size_t level = 0; + const size_t seq_num = in_lod[level].size() - 1; + + // These local variables hold the inputs and outputs, garanteeing them on + // CPU memory, to provide a consistent reference. 
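+    // Note: even when this operator is assigned to a GPU place, the actual
+    // forward computation below still runs on the CPU. In that case the
+    // inputs are first copied into these CPU-side tensors (see
+    // CopyInputsToCpuMemory), the per-sequence forward pass is evaluated
+    // there, and the results are copied back to the GPU outputs (see
+    // CopyOutputsToGpuMemory).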
+ // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the whole training process. + LoDTensor* emission_weights = nullptr; + LoDTensor emission_weight_tensor; + Tensor* transition_weights = nullptr; + Tensor transition_weight_tensor; + LoDTensor* label = nullptr; + LoDTensor label_tensor; + + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor* ll = nullptr; + Tensor ll_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + emission_weights = &emission_weight_tensor; + transition_weights = &transition_weight_tensor; + label = &label_tensor; + + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Emission"), + *ctx.Input("Transition"), *ctx.Input("Label"), + emission_weights, transition_weights, label); + + emission_exps = &emission_exps_tensor; + emission_exps->Resize(emission_weights->dims()); + + transition_exps = &transition_exps_tensor; + transition_exps->Resize(transition_weights->dims()); + + alpha = &alpha_tensor; + alpha->Resize(ctx.Output("Alpha")->dims()); + + ll = &ll_tensor; + } else { + emission_weights = + const_cast(ctx.Input("Emission")); + transition_weights = const_cast(ctx.Input("Transition")); + label = const_cast(ctx.Input("Label")); + + emission_exps = ctx.Output("EmissionExps"); + transition_exps = ctx.Output("TransitionExps"); + alpha = ctx.Output("Alpha"); + ll = ctx.Output("LogLikelihood"); + } + + // Because the computation codes only runs on CPU, here the memory for all + // the outputs is FIXED to be allocated on the CPU memory. + emission_exps->mutable_data(platform::CPUPlace()); + transition_exps->mutable_data(platform::CPUPlace()); + alpha->mutable_data(platform::CPUPlace()); + + // Resize the output tensor to its correct dimension. + ll->Resize({static_cast(seq_num), 1}); + ll->mutable_data(platform::CPUPlace()); + + // Now, all the inputs and outputs should be on the CPU memory. + auto emission_dims = emission_weights->dims(); + const size_t batch_size = emission_dims[0]; + const size_t tag_num = emission_dims[1]; + + Tensor emission_row_max; + emission_row_max.mutable_data( + framework::make_ddim({static_cast(batch_size), 1}), + platform::CPUPlace()); + + auto place = ctx.GetEigenDevice(); + auto x = EigenMatrix::From(*emission_weights); + auto x_row_max = EigenMatrix::From(emission_row_max); + x_row_max.device(place) = + x.maximum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(int(batch_size), 1)); + + auto x_exps = EigenMatrix::From(*emission_exps); + x_exps.device(place) = + (x - x_row_max.broadcast(Eigen::DSizes(1, tag_num))).exp(); + + auto w = EigenMatrix::From(*transition_weights); + auto w_exps = EigenMatrix::From(*transition_exps); + w_exps.device(place) = w.exp(); + + T* log_likelihood = ll->data(); + for (size_t i = 0; i < seq_num; ++i) { + int start_pos = static_cast(in_lod[level][i]); + int end_pos = static_cast(in_lod[level][i + 1]); + if (end_pos == start_pos) { + // If an empty input sequence is given, pad 0 for its cost. 
+ log_likelihood[i] = 0.; + continue; + } + + const Tensor one_seq = emission_weights->Slice(start_pos, end_pos); + Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos); + Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos); + const Tensor one_seq_label = label->Slice(start_pos, end_pos); + Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos); + + log_likelihood[i] = ForwardOneSequence( + one_seq, one_seq_row_max, one_seq_exps, *transition_weights, + *transition_exps, one_seq_label, &one_seq_alpha); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), *emission_exps, *transition_exps, *alpha, *ll, + ctx.Output("EmissionExps"), + ctx.Output("TransitionExps"), ctx.Output("Alpha"), + ctx.Output("LogLikelihood")); + } + }; + + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& emission_weights_src, + const Tensor& transition_weights_src, + const LoDTensor& label_src, + LoDTensor* emission_weights_dst, + Tensor* transition_weights_dst, + LoDTensor* label_dst) const { + // Copy the inputs from GPU memory to CPU memory if this operators runs on + // GPU device. + auto copyLoDTensor = [](const platform::DeviceContext& ctx, + const LoDTensor& src, LoDTensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + }; + + copyLoDTensor(ctx, emission_weights_src, emission_weights_dst); + copyLoDTensor(ctx, label_src, label_dst); + + transition_weights_dst->mutable_data(transition_weights_src.dims(), + platform::CPUPlace()); + transition_weights_dst->CopyFrom(transition_weights_src, + platform::CPUPlace(), ctx); + } + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_src, + Tensor* emission_exps_dst, + Tensor* transition_exps_dst, Tensor* alpha_dst, + Tensor* ll_dst) const { + // Copy the forward results from CPU memory to GPU memory if this + // operators runs on GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(src, platform::GPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_src, ll_dst); + } + + T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max, + const Tensor& emission_exps, const Tensor& trans_weights, + const Tensor& trans_weight_exps, const Tensor& label, + Tensor* alpha) const { + const T* x = emission.data(); + const T* x_row_max = emission_row_max.data(); + const T* x_exps = emission_exps.data(); + const T* w = trans_weights.data(); + const T* w_exps = trans_weight_exps.data(); + T* alpha_value = alpha->data(); + + auto x_dims = emission.dims(); + const size_t seq_length = x_dims[0]; + const size_t tag_num = x_dims[1]; + // The 1st row of w are transition weights for start mask. + // The 2nd row of w are transition weights for end mask. + // Transition weights between other tags begin from the 3rd row of w. 
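+    // In index form: the (exponentiated) weight for moving from tag j to
+    // tag i is stored at offset (j + state_trans_base_idx) * tag_num + i of
+    // w_exps, which is exactly how it is read in the recurrence below.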
+ const size_t state_trans_base_idx = 2; + + for (size_t i = 0; i < tag_num; ++i) { + alpha_value[i] = w_exps[i] * x_exps[i]; + } + T ll = -x_row_max[0] - std::log(NormalizeL1(alpha_value, tag_num)); + + for (size_t k = 1; k < seq_length; ++k) { + for (size_t i = 0; i < tag_num; ++i) { + T sum = 0.; + for (size_t j = 0; j < tag_num; ++j) { + sum += alpha_value[(k - 1) * tag_num + j] * // (*) + w_exps[(j + state_trans_base_idx) * tag_num + i]; + } + alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum; + } + // NormalizeL1 is to avoid underflow or overflow at (*). + ll -= x_row_max[k] + + std::log(NormalizeL1(alpha_value + k * tag_num, tag_num)); + } + T sum = 0.; + for (size_t i = 0; i < tag_num; ++i) { + sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i]; + } + ll -= std::log(sum); + // Now ll is equal to -log(Z). + + const int* lbl = label.data(); + PADDLE_ENFORCE_LT( + *std::max_element(lbl, lbl + seq_length), tag_num, + "An invalid tag label that execesses the largest tag number."); + + // Calculate the nominator part, which depends on the label sequence. + ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] + + w[tag_num + lbl[seq_length - 1]] /*end transition*/; + for (size_t k = 1; k < seq_length; ++k) { + ll += x[k * tag_num + lbl[k]] + + w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]]; + } + return -ll; + } +}; + +template +class LinearChainCRFGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const size_t level = 0; // currently, only support sequence. + auto lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(lod.size(), "Input(Label) must be a sequence."); + + // These local variables hold the inputs and outputs, garanteeing them on + // CPU memory, to provide a consistent reference. + // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the training process, or + // implementing a real GPU kernel for CRF. 
+ Tensor* label = nullptr; + Tensor label_tensor; + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor ll_grad_tensor; + T* ll_grad = nullptr; + + Tensor* emission_grad = nullptr; + Tensor emission_grad_tensor; + Tensor* transition_grad = nullptr; + Tensor transition_grad_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + label = &label_tensor; + emission_exps = &emission_exps_tensor; + transition_exps = &transition_exps_tensor; + alpha = &alpha_tensor; + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Label"), + *ctx.Input("EmissionExps"), + *ctx.Input("TransitionExps"), *ctx.Input("Alpha"), + *ctx.Input(framework::GradVarName("LogLikelihood")), label, + emission_exps, transition_exps, alpha, &ll_grad_tensor); + ll_grad = ll_grad_tensor.data(); + + if (ctx.Output(framework::GradVarName("Emission"))) { + emission_grad = &emission_grad_tensor; + emission_grad->Resize(emission_exps->dims()); + } + + if (ctx.Output(framework::GradVarName("Transition"))) { + transition_grad = &transition_grad_tensor; + transition_grad->Resize(transition_exps->dims()); + } + } else { + label = const_cast(ctx.Input("Label")); + emission_exps = const_cast(ctx.Input("EmissionExps")); + transition_exps = + const_cast(ctx.Input("TransitionExps")); + alpha = const_cast(ctx.Input("Alpha")); + ll_grad = const_cast( + ctx.Input(framework::GradVarName("LogLikelihood"))) + ->data(); + + emission_grad = ctx.Output(framework::GradVarName("Emission")); + transition_grad = + ctx.Output(framework::GradVarName("Transition")); + } + + // TODO(caoying) Fix this constraint. When the Input(Emission) is from the + // data reader operator, it can have no gradients. + PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null."); + emission_grad->mutable_data(platform::CPUPlace()); + if (transition_grad) { + transition_grad->mutable_data(platform::CPUPlace()); + math::SetConstant()(ctx.device_context(), + transition_grad, 0.); + } + // Now, all the inputs and outputs should be on the CPU memory. + + auto emission_dims = emission_exps->dims(); + // Beta is the memo table used in dynamic programming to calculate the + // backwark vectors. For a backward vector i (the i-th row of beta), it + // captures the unnormalized probabilities of partial sequences starting + // at position i. 
+ Tensor beta; + beta.mutable_data(emission_dims, platform::CPUPlace()); + + for (size_t i = 0; i < lod[level].size() - 1; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + if (end_pos == start_pos) continue; + + const Tensor one_seq_emission_exps = + emission_exps->Slice(start_pos, end_pos); + const Tensor one_seq_label = label->Slice(start_pos, end_pos); + const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos); + Tensor one_seq_beta = beta.Slice(start_pos, end_pos); + Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos); + + BackwardOneSequence(ctx.device_context(), ll_grad[i], + one_seq_emission_exps, *transition_exps, + one_seq_alpha, one_seq_label, &one_seq_beta, + transition_grad, &one_seq_emission_grad); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), emission_grad, transition_grad, + ctx.Output(framework::GradVarName("Emission")), + ctx.Output(framework::GradVarName("Transition"))); + } + }; + + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& label_src, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_grad_src, + Tensor* label_dst, Tensor* emission_exps_dst, + Tensor* transition_exps_dst, Tensor* alpha_dst, + Tensor* ll_grad_dst) const { + // Copy the inputs from GPU memory to CPU memory when this operators runs on + // GPU device. + label_dst->mutable_data(label_src.dims(), platform::CPUPlace()); + label_dst->CopyFrom(label_src, platform::CPUPlace(), ctx); + + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_grad_src, ll_grad_dst); + } + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor* emission_grad_src, + const Tensor* transition_grad_src, + Tensor* emission_grad_dst, + Tensor* transition_grad_dst) const { + // Copy the backward results from CPU memory to GPU + // memory if this operators runs on GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor* src, + Tensor* dst) { + if (src && dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(*src, platform::GPUPlace(), ctx); + } + }; + copyTensor(ctx, emission_grad_src, emission_grad_dst); + copyTensor(ctx, transition_grad_src, transition_grad_dst); + } + + void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad, + const Tensor& emission_exps, + const Tensor& transition_exps, const Tensor& alpha, + const Tensor& label, Tensor* beta, + Tensor* transition_grad, + Tensor* emission_grad) const { + const T* w_exps = transition_exps.data(); + const T* x_exps = emission_exps.data(); + const int* label_value = label.data(); + T* beta_value = beta->data(); + + auto x_dims = emission_exps.dims(); + const size_t seq_length = x_dims[0]; + const size_t tag_num = x_dims[1]; + const size_t state_trans_base_idx = 2; + + // Calculate the backward vectors: beta. + // First, calculate the initialition state. 
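+    // The last row of beta is seeded with the exponentiated ending weights
+    // (the 2nd row of the transition matrix) and then normalized; the earlier
+    // rows are filled in by the backward recurrence that follows.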
+ for (size_t i = 0; i < tag_num; ++i) { + beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i]; + } + NormalizeL1(beta_value + (seq_length - 1) * tag_num, tag_num); + for (int k = static_cast(seq_length) - 2; k >= 0; --k) { + for (size_t i = 0; i < tag_num; ++i) { + T sum = 0.; + for (size_t j = 0; j < tag_num; ++j) { + sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**) + x_exps[(k + 1) * tag_num + j] * + beta_value[(k + 1) * tag_num + j]; + } + beta_value[k * tag_num + i] = sum; + } + // NormalizeL1 is to avoid underflow or overflow at (**). + NormalizeL1(beta_value + k * tag_num, tag_num); + } + + auto x_grad_mat = EigenMatrix::From(*emission_grad); + auto alpha_mat = EigenMatrix::From(alpha); + auto beta_mat = EigenMatrix::From(*beta); + + auto* place = ctx.GetEigenDevice(); + auto prob = alpha_mat * beta_mat; + auto row_sum = prob.sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(seq_length, 1)) + .broadcast(Eigen::DSizes(1, tag_num)); + x_grad_mat.device(*place) = + (prob / row_sum).unaryExpr(ScalarMul(ll_grad)); + + for (size_t k = 0; k < seq_length; ++k) { + x_grad_mat(k, label_value[k]) -= static_cast(ll_grad); + } + + if (transition_grad) { + T* trans_grad = transition_grad->data(); + for (size_t k = 0; k < tag_num; ++k) { + // Do not multiply by the output gradient here, because x_grad_mat has + // alrealy done this. + trans_grad[k] += x_grad_mat(/*from start state*/ 0, k); + trans_grad[tag_num + k] += + x_grad_mat(/*to end state*/ seq_length - 1, k); + } + + auto x_exps_mat = EigenMatrix::From(emission_exps); + + // TODO(caoying): Fix this to avoid using this local variable if we can + // profile the training process. + Tensor tmp; + tmp.mutable_data(beta->dims(), platform::CPUPlace()); + auto tmp_mat = EigenMatrix::From(tmp); + auto prob = beta_mat * x_exps_mat; + auto row_sum = prob.sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(seq_length, 1)) + .broadcast(Eigen::DSizes(1, tag_num)); + tmp_mat.device(*place) = prob / row_sum; + + for (size_t k = 1; k < seq_length; ++k) { + T sum = 0.; + for (size_t i = 0; i < tag_num; ++i) { + for (size_t j = 0; j < tag_num; ++j) { + sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**) + alpha_mat(k - 1, i) * tmp_mat(k, j); + } + } + sum = 1. 
/ sum; + for (size_t i = 0; i < tag_num; ++i) { + for (size_t j = 0; j < tag_num; ++j) { + trans_grad[(i + state_trans_base_idx) * tag_num + j] += + sum * w_exps[(i + state_trans_base_idx) * tag_num + j] * + alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad; + } + } + trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num + + label_value[k]] -= static_cast(ll_grad); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 0a089b7c2dc1e05224525bc4fe5399ec39036d01..94342d940704d850a2a45c281a3d88de5a132753 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of LSTM should not be null."); @@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel { "Output(Hidden) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), "Output(Cell) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(BatchGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), + "Output(BatchGate) of LSTM should not be null."); - auto x_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); if (ctx->HasInput("H0")) { PADDLE_ENFORCE(ctx->HasInput("C0"), @@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel { "should be the same."); } - int frame_size = x_dims[1] / 4; + int frame_size = in_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2."); @@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel { "4 * %d if disable peepholes connection", frame_size); } - ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); - ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); - ctx->SetOutputDim("BatchGate", x_dims); + framework::DDim out_dims({in_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("BatchGate", in_dims); + ctx->SetOutputDim("BatchCellPreAct", out_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input("Input")->type()); + } }; class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { @@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(LoDTensor) the first input is a LodTensor, which support " "variable-time length input sequence. The underlying tensor in " - "this LoDTensor is a matrix with shape (T X 4D), where, T is the " + "this LoDTensor is a matrix with shape (T X 4D), where T is the " "total time steps in this mini-batch, D is the hidden size."); AddInput("H0", "(Tensor, optional) the initial hidden state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size."); + "batch size, D is the hidden size.") + .AsDispensable(); AddInput("C0", "(Tensor, optional) the initial cell state is an optional " "input. 
This is a tensor with shape (N x D), where N is the " - "batch size. `H0` and `C0` can be NULL but only at the same time"); + "batch size. `H0` and `C0` can be NULL but only at the same time") + .AsDispensable(); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." " - The shape is (D x 4D), where D is the hidden size. " @@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { " - Bias = {b_c, b_i, b_f, b_o}." "2. `usePeepholes = True` " " - The shape is (1 x 7D). " - " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.") + .AsDispensable(); + AddOutput("Hidden", + "(LoDTensor) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " "and output gate after the nonlinear computation. This " "LoDTensor has the same shape with the reorganized input, which " - "was also be called batch input. The LoD size is 2. The first " + "is also be called batch input. The LoD size is 2. The first " "LoD is the batch offsets and the second LoD contains the " "indexes, which denote the position of reorganized sequence " "in the raw input.") .AsIntermediate(); - AddOutput("Hidden", - "(LoDTensor) the hidden state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); - AddOutput("Cell", - "(LoDTensor) the cell state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); + AddOutput("BatchCellPreAct", + "(LoDTensor) This LoDTensor is got in the forward and used " + "in the backward.") + .AsIntermediate(); AddAttr("usePeepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), - "Input(Hidden@GRAD) should not be null"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), - "Input(Cell@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("Weight"), - ctx->GetInputDim("Weight")); - ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Cell"), + "Input(Cell) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("BatchGate"), + "Input(BatchGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), + "Input(BatchGate) of LSTM should not be null."); + + auto in_g_name = framework::GradVarName("Input"); + if (ctx->HasOutput(in_g_name)) + ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input")); + + auto w_g_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(w_g_name)) + ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight")); + + auto b_g_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(b_g_name)) + ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { 
+ return framework::ToDataType( + ctx.Input("Input")->type()); } }; diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 0af5694c48fcb4437e3acd422606de013bb2e145..af088b80b4283cf221a1dff74546d73d977fada3 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -21,8 +21,9 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::LoDTensor; -using framework::Tensor; +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + template using EigenMatrix = framework::EigenMatrix; @@ -31,15 +32,15 @@ template class LSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); - auto* weight = ctx.Input("Weight"); - auto* bias = ctx.Input("Bias"); + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); - auto* batch_gate = ctx.Output("BatchGate"); + auto* batch_gate = ctx.Output("BatchGate"); batch_gate->mutable_data(ctx.GetPlace()); - auto* hidden_out = ctx.Output("Hidden"); + auto* hidden_out = ctx.Output("Hidden"); hidden_out->mutable_data(ctx.GetPlace()); - auto* cell_out = ctx.Output("Cell"); + auto* cell_out = ctx.Output("Cell"); cell_out->mutable_data(ctx.GetPlace()); // Now the function ShareLoD in InferShape is not implemented. @@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel { bool is_reverse = ctx.Attr("isReverse"); math::LoDTensor2BatchFunctor to_batch; - to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); + auto& device_ctx = ctx.device_context(); + to_batch(device_ctx, *input, *batch_gate, true, is_reverse); auto in_dims = input->dims(); int frame_size = static_cast(in_dims[1] / 4); @@ -69,17 +71,26 @@ class LSTMKernel : public framework::OpKernel { } math::LstmMetaValue lstm_value; - T* bias_data = const_cast(bias->data()); - // the code style in LstmMetaValue will be updated later. - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + if (bias) { + T* bias_data = const_cast(bias->data()); + // the code style in LstmMetaValue will be updated later. + + lstm_value.checkIg = bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + } else { + lstm_value.checkIg = nullptr; + lstm_value.checkFg = nullptr; + lstm_value.checkOg = nullptr; + } lstm_value.prevStateValue = nullptr; - framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; - batch_out.mutable_data(dims, ctx.GetPlace()); + // Use the local variable as here. 
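+    // batch_hidden and batch_cell are only needed within this forward pass,
+    // so plain local tensors are sufficient; BatchCellPreAct is an operator
+    // output because the backward kernel reads it again.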
+ LoDTensor batch_hidden, batch_cell; + auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + batch_hidden.mutable_data(dims, ctx.GetPlace()); batch_cell.mutable_data(dims, ctx.GetPlace()); - batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); + batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; @@ -92,18 +103,18 @@ class LSTMKernel : public framework::OpKernel { int bend = static_cast(batch_starts[n + 1]); Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor out_t = batch_out.Slice(bstart, bend); + Tensor out_t = batch_hidden.Slice(bstart, bend); Tensor cell_t = batch_cell.Slice(bstart, bend); - Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); int cur_batch_size = bend - bstart; if (n != 0) { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; - auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); - math::matmul(ctx.device_context(), pre_hidden_t, false, - *weight, false, static_cast(1.0), &gate_t, + auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_hidden_t, false, *weight, false, + static_cast(1.0), &gate_t, static_cast(1.0)); } // else if : FIXME support the initial hidden and cell @@ -112,27 +123,186 @@ class LSTMKernel : public framework::OpKernel { lstm_value.outputValue = out_t.data(); lstm_value.stateValue = cell_t.data(); lstm_value.stateActiveValue = cell_pre_act_t.data(); - math::LstmUnitFunctor::compute(ctx.device_context(), lstm_value, + math::LstmUnitFunctor::compute(device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, cand_act); lstm_value.prevStateValue = lstm_value.stateValue; } math::Batch2LoDTensorFunctor to_seq; - batch_out.set_lod(batch_gate->lod()); + batch_hidden.set_lod(batch_gate->lod()); // restore the output hidden in LoDTensor from the batch hidden - to_seq(ctx.device_context(), batch_out, *hidden_out); + to_seq(device_ctx, batch_hidden, *hidden_out); batch_cell.set_lod(batch_gate->lod()); // restore the output cell state in LoDTensor from the batch cell - to_seq(ctx.device_context(), batch_cell, *cell_out); + to_seq(device_ctx, batch_cell, *cell_out); } }; template class LSTMGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* hidden_out = ctx.Input("Hidden"); + auto* cell_out = ctx.Input("Cell"); + + auto* batch_gate = ctx.Input("BatchGate"); + auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); + + auto* hidden_g = ctx.Input(framework::GradVarName("Hidden")); + + auto* in_g = ctx.Output(framework::GradVarName("Input")); + auto* weight_g = ctx.Output(framework::GradVarName("Weight")); + auto* bias_g = ctx.Output(framework::GradVarName("Bias")); + + auto& device_ctx = ctx.device_context(); + math::SetConstant zero; + if (weight_g) { + weight_g->mutable_data(ctx.GetPlace()); + zero(device_ctx, weight_g, static_cast(0.0)); + } + + auto in_dims = input->dims(); + auto out_dims = hidden_g->dims(); + int frame_size = static_cast(in_dims[1] / 4); + PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); + + math::LstmMetaValue lstm_value; + if (bias) { + T* bias_data = const_cast(bias->data()); + lstm_value.checkIg = 
bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + } else { + lstm_value.checkIg = nullptr; + lstm_value.checkFg = nullptr; + lstm_value.checkOg = nullptr; + } + + math::LstmMetaGrad lstm_grad; + if (bias && bias_g) { + T* bias_g_data = const_cast(bias_g->mutable_data(ctx.GetPlace())); + zero(device_ctx, bias_g, static_cast(0.0)); + lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; + lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; + lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; + } else { + lstm_grad.checkIgGrad = nullptr; + lstm_grad.checkFgGrad = nullptr; + lstm_grad.checkOgGrad = nullptr; + } + + math::LoDTensor2BatchFunctor to_batch; + + // use the local variable as here. + LoDTensor batch_hidden; + batch_hidden.mutable_data(out_dims, ctx.GetPlace()); + batch_hidden.set_lod(batch_gate->lod()); + to_batch(device_ctx, *hidden_out, batch_hidden, false); + + LoDTensor batch_hidden_g; + batch_hidden_g.mutable_data(out_dims, ctx.GetPlace()); + batch_hidden_g.set_lod(batch_gate->lod()); + to_batch(device_ctx, *hidden_g, batch_hidden_g, false); + + LoDTensor batch_cell; + batch_cell.mutable_data(out_dims, ctx.GetPlace()); + batch_cell.set_lod(batch_gate->lod()); + to_batch(device_ctx, *cell_out, batch_cell, false); + + LoDTensor batch_cell_g; + batch_cell_g.mutable_data(out_dims, ctx.GetPlace()); + batch_cell_g.set_lod(batch_gate->lod()); + // TODO(qingqing) support the case output cell has gradient. + // to_batch(device_ctx, *cell_g, batch_cell_g, false); + zero(device_ctx, &batch_cell_g, static_cast(0.0)); + + LoDTensor batch_gate_g; + batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); + batch_gate_g.set_lod(batch_gate->lod()); + + auto gate_act = ctx.Attr("gateActivation"); + auto cell_act = ctx.Attr("cellActivation"); + auto cand_act = ctx.Attr("candidateActivation"); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate = batch_gate->Slice(bstart, bend); + Tensor cell = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); + lstm_value.gateValue = gate.data(); + lstm_value.stateValue = cell.data(); + lstm_value.stateActiveValue = cell_pre_act.data(); + + Tensor out_g = batch_hidden_g.Slice(bstart, bend); + Tensor gate_g = batch_gate_g.Slice(bstart, bend); + Tensor cell_g = batch_cell_g.Slice(bstart, bend); + lstm_grad.stateGrad = cell_g.data(); + lstm_grad.gateGrad = gate_g.data(); + lstm_grad.outputGrad = out_g.data(); + + if (n) { + int bstart_pre = static_cast(batch_starts[n - 1]); + Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); + Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); + lstm_value.prevStateValue = cell_pre.data(); + lstm_grad.prevStateGrad = cell_pre_g.data(); + } else { + lstm_value.prevStateValue = nullptr; + lstm_grad.prevStateGrad = nullptr; + } + + int cur_batch_size = bend - bstart; + math::LstmUnitGradFunctor::compute( + device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, + gate_act, cell_act, cand_act); + + if (n != 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, gate_g, false, *weight, true, + static_cast(1.0), 
&pre_hidden_g, + static_cast(1.0)); + if (weight_g) { + /* backward weight */ + auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_hidden, true, gate_g, false, + static_cast(1.0), weight_g, + static_cast(1.0)); + } + } + } + + math::Batch2LoDTensorFunctor to_seq; + if (in_g) { + /* backward data */ + in_g->mutable_data(ctx.GetPlace()); + to_seq(device_ctx, batch_gate_g, *in_g); + } + if (bias && bias_g) { + /* backward bias */ + int m = static_cast(batch_gate_g.dims()[0]); + int n = static_cast(batch_gate_g.dims()[1]); + + Tensor ones; + ones.mutable_data({m}, ctx.GetPlace()); + math::SetConstant set; + set(device_ctx, &ones, static_cast(1.0)); + + math::gemv(device_ctx, true, m, n, 1., batch_gate_g.data(), + ones.data(), 0., bias_g->data()); + } + } }; } // namespace operators diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index 74d51d7bc9b91f4c8088384d77183131f57aafab..d0ed55ea168bc3e701c421c51d662c646e475351 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -26,10 +26,7 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + LstmMetaGrad grad, int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, act(active_node), act(active_gate), act(active_state)); + rCheckOGrad); gradIn[i] = rGradIn; gradIg[i] = rGradIg; @@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, avx_lstm_forward_one_sequence(op, value, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize, active_node, - active_gate, active_state); + naive_lstm_forward_one_sequence(op, value, frameSize); } } @@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + naive_lstm_backward_one_sequence(op, value, grad, frameSize); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 
9573eaefb6a9d678ef70f2e2bffdc6a3011b21ea..c06f164f84a92d31f89901e2656bdb8e69c533b7 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -32,9 +32,7 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - act(active_node), act(active_gate), act(active_state)); + rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } } @@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 6f3ead2397d5131b4468d0ad288513cedb289594..461039a4d51a2b9b8a55d3101bdf4c511907597e 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -24,15 +24,29 @@ namespace detail { namespace forward { +template +DEVICE inline T sigmoid(const T a) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +template +DEVICE inline T tanh(const T a) { + T tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? 
EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + template class lstm { public: HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, T &prevState, T &state, T &stateAtv, T &output, - T &checkI, T &checkF, T &checkO, - typename hppl::ForwardActType::type actInput, - typename hppl::ForwardActType::type actGate, - typename hppl::ForwardActType::type actState) { + T &checkI, T &checkF, T &checkO) { +#if 0 + // TODO(qingqing) support to activation speficed by users valueIn = actInput(valueIn); valueIg = actGate(valueIg + prevState * checkI); valueFg = actGate(valueFg + prevState * checkF); @@ -40,6 +54,15 @@ class lstm { valueOg = actGate(valueOg + state * checkO); stateAtv = actState(state); output = valueOg * stateAtv; +#else + valueIn = tanh(valueIn); + valueIg = sigmoid(valueIg + prevState * checkI); + valueFg = sigmoid(valueFg + prevState * checkF); + state = valueIn * valueIg + prevState * valueFg; + valueOg = sigmoid(valueOg + state * checkO); + stateAtv = tanh(state); + output = valueOg * stateAtv; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default @@ -72,6 +95,16 @@ class lstm { namespace backward { +template +DEVICE inline T sigmoid(const T a, const T b) { + return a * b * (1.0 - b); +} + +template +DEVICE inline T tanh(const T a, const T b) { + return a * (1.0 - b * b); +} + template class lstm { public: @@ -80,10 +113,9 @@ class lstm { T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad, - T &checkFGrad, T &checkOGrad, - typename hppl::BackwardActType::type actInput, - typename hppl::BackwardActType::type actGate, - typename hppl::BackwardActType::type actState) { + T &checkFGrad, T &checkOGrad) { +#if 0 + // TODO(qingqing) support to activation speficed by users gradOg = actGate(outputGrad * stateAtv, valueOg); stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradIn = actInput(stateGrad * valueIg, valueIn); @@ -93,6 +125,17 @@ class lstm { checkIGrad = gradIg * prevState; checkFGrad = gradFg * prevState; checkOGrad = gradOg * state; +#else + gradOg = sigmoid(outputGrad * stateAtv, valueOg); + stateGrad += tanh(outputGrad * valueOg, stateAtv) + gradOg * checkO; + gradIn = tanh(stateGrad * valueIg, valueIn); + gradIg = sigmoid(stateGrad * valueIn, valueIg); + gradFg = sigmoid(stateGrad * prevState, valueFg); + prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; + checkIGrad = gradIg * prevState; + checkFGrad = gradFg * prevState; + checkOGrad = gradOg * state; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index aad1357598c629a4edfe0ad9b23f0241093a2522..2a9c09a0f16b71473e21765ab9253eb7b8bcf28c 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -211,6 +211,26 @@ void batched_gemm( } #endif +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const float alpha, + const float* A, const float* B, + const float beta, float* C) { + CBLAS_TRANSPOSE transA = (trans_a == false) ? 
CblasNoTrans : CblasTrans; + cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const double alpha, + const double* A, const double* B, + const double beta, double* C) { + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 5583683c6e12b88ba81015aef9161913de261ef2..e6fd8bf235b8539702ca2c5b39e305cb1becf5cb 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -203,6 +203,33 @@ void batched_gemm( &beta, C, ldc, strideC, batchCount)); } +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const float alpha, + const float* A, const float* B, + const float beta, float* C) { + cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; + + PADDLE_ENFORCE(platform::dynload::cublasSgemv( + reinterpret_cast(context) + .cublas_handle(), + cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); +} + +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const double alpha, + const double* A, const double* B, + const double beta, double* C) { + cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE(platform::dynload::cublasDgemv( + reinterpret_cast(context) + .cublas_handle(), + cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); +} + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 9777ebfd156709a370be2cb4ba0077ac7c6735fb..3bb5aa0332c7e2a63d20b91893c03ccd468dd863 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -93,6 +93,11 @@ void batched_gemm(const platform::DeviceContext& context, const T* A, const T* B, const T beta, T* C, const int batchCount, const int strideA, const int strideB); +template +void gemv(const platform::DeviceContext& context, const bool trans_a, + const int M, const int N, const T alpha, const T* A, const T* B, + const T beta, T* C); + template struct SetConstant { void operator()(const platform::DeviceContext& context, diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 3b9f92e7ae5f34dd0fb1ba8fb0c67ff5ae1628c4..7d84ad9aadb2892db0d0ee9cab428dc5036614e9 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -89,3 +89,53 @@ TEST(math_function, zero) { EXPECT_EQ(t[2], 1); EXPECT_EQ(t[3], 1); } + +template +void GemvTest(int m, int n, bool trans) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor vec_b; + paddle::framework::Tensor vec_c; + auto* cpu_place = new paddle::platform::CPUPlace(); + int b_num = trans ? m : n; + int c_num = trans ? 
n : m; + + T* data_a = mat_a.mutable_data({m, n}, *cpu_place); + T* data_b = vec_b.mutable_data({b_num}, *cpu_place); + T* data_c = vec_c.mutable_data({c_num}, *cpu_place); + for (int i = 0; i < mat_a.numel(); ++i) { + data_a[i] = static_cast(i); + } + for (int i = 0; i < vec_b.numel(); ++i) { + data_b[i] = static_cast(i); + } + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemv( + context, trans, static_cast(m), static_cast(n), 1., data_a, + data_b, 0., data_c); + + if (!trans) { + for (int i = 0; i < m; ++i) { + T sum = 0.0; + for (int j = 0; j < n; ++j) { + sum += data_a[i * n + j] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } else { + for (int i = 0; i < n; ++i) { + T sum = 0.0; + for (int j = 0; j < m; ++j) { + sum += data_a[j * n + i] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } +} + +TEST(math_function, gemv) { + GemvTest(3, 13, false); + GemvTest(4, 5, false); + GemvTest(12, 7, true); + GemvTest(7, 9, true); +} diff --git a/paddle/operators/math/math_function_test.cu b/paddle/operators/math/math_function_test.cu index 8b22c71552a65044cbd02441fb35c1eafe0173dc..780d17ffc6539c5f4d67ebab5476d6f646840b41 100644 --- a/paddle/operators/math/math_function_test.cu +++ b/paddle/operators/math/math_function_test.cu @@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) { EXPECT_EQ(input3_ptr[7], 99); delete gpu_place; } + +template +void GemvTest(int m, int n, bool trans) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor vec_b; + paddle::framework::Tensor vec_c; + auto* cpu_place = new paddle::platform::CPUPlace(); + + T* data_a = mat_a.mutable_data({m, n}, *cpu_place); + T* data_b = vec_b.mutable_data({trans ? m : n}, *cpu_place); + T* data_c = vec_c.mutable_data({trans ? 
n : m}, *cpu_place); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::framework::Tensor g_mat_a; + paddle::framework::Tensor g_vec_b; + paddle::framework::Tensor g_vec_c; + T* g_data_a = g_mat_a.mutable_data(mat_a.dims(), *gpu_place); + T* g_data_b = g_vec_b.mutable_data(vec_b.dims(), *gpu_place); + T* g_data_c = g_vec_c.mutable_data(vec_c.dims(), *gpu_place); + + for (int i = 0; i < mat_a.numel(); ++i) { + data_a[i] = static_cast(i); + } + for (int i = 0; i < vec_b.numel(); ++i) { + data_b[i] = static_cast(i); + } + + paddle::platform::CUDADeviceContext context(*gpu_place); + g_mat_a.CopyFrom(mat_a, *gpu_place, context); + g_vec_b.CopyFrom(vec_b, *gpu_place, context); + + paddle::operators::math::gemv( + context, trans, static_cast(m), static_cast(n), 1., g_data_a, + g_data_b, 0., g_data_c); + + vec_c.CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context); + + if (!trans) { + for (int i = 0; i < m; ++i) { + T sum = 0.0; + for (int j = 0; j < n; ++j) { + sum += data_a[i * n + j] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } else { + for (int i = 0; i < n; ++i) { + T sum = 0.0; + for (int j = 0; j < m; ++j) { + sum += data_a[j * n + i] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } +} + +TEST(math_function, gemv) { + GemvTest(3, 13, false); + GemvTest(3, 13, false); + GemvTest(3, 13, true); + GemvTest(3, 13, true); +} diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 03cd018e46e90c9bbe689c9686377e0e998ee513..b1ba35a6d4a891e9152ac2088bc76e3969be6405 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -53,7 +53,18 @@ class LoDTensor2BatchFunctor { public: void operator()(const platform::DeviceContext& context, const framework::LoDTensor& lod_tensor, - framework::LoDTensor& batch, bool is_reverse) const { + framework::LoDTensor& batch, bool is_cal_batch_lod, + bool is_reverse = false) const { + if (!is_cal_batch_lod) { + auto lods = batch.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 2UL); + PADDLE_ENFORCE_EQ(lods[1].size(), + static_cast(lod_tensor.dims()[0])); + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, lods[1].data(), batch, true); + return; + } + auto lods = lod_tensor.lod(); PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); auto lod = lods[0]; @@ -101,10 +112,10 @@ class LoDTensor2BatchFunctor { size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; - for (size_t n = 0; n < num_batch; n++) { + for (int n = 0; n < num_batch; n++) { auto batch_id = static_cast(batch_starts[n]); for (size_t i = 0; i < seq_info.size(); ++i) { - size_t seq_len = seq_info[i].length; + int seq_len = seq_info[i].length; int start = seq_info[i].start; if (n < seq_len) { seq2batch_idx[batch_id] = @@ -132,11 +143,8 @@ class Batch2LoDTensorFunctor { auto in_lod = batch.lod(); PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, "The LoD size of input `batch` should be 2."); - auto out_lod = lod_tensor.lod()[0]; - auto num = out_lod[out_lod.size() - 1]; - PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); - PADDLE_ENFORCE_EQ(num, in_lod[1].size()); - PADDLE_ENFORCE_EQ(num, batch.dims()[0]); + PADDLE_ENFORCE_EQ(in_lod[1].size(), + static_cast(lod_tensor.dims()[0])); CopyMatrixRowsFunctor to_seq; size_t* index = in_lod[1].data(); to_seq(context, batch, index, lod_tensor, false); diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu index 
80c50a28a9e5d560fc693c518b9e62091ddc5724..e5927d56ae7cfbd09e941c993041af46ecd8d70d 100644 --- a/paddle/operators/nccl_op_test.cu +++ b/paddle/operators/nccl_op_test.cu @@ -185,7 +185,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) { recv_tensor.numel() * sizeof(float), static_cast(dev_ctxs[i])->stream()); - for (size_t j = 0; j < f::product(kDims); ++j) { + for (int64_t j = 0; j < f::product(kDims); ++j) { ASSERT_NEAR(ct[j], result, 1e-5); } } @@ -234,7 +234,7 @@ TEST_F(NCCLTester, ncclReduceOp) { recv_tensor.numel() * sizeof(float), static_cast(dev_ctxs[kRoot])->stream()); - for (int j = 0; j < f::product(kDims); ++j) { + for (int64_t j = 0; j < f::product(kDims); ++j) { ASSERT_NEAR(ct[j], result, 1e-5); } } @@ -282,7 +282,7 @@ TEST_F(NCCLTester, ncclBcastOp) { recv_tensor.numel() * sizeof(float), static_cast(dev_ctxs[idx])->stream()); - for (size_t j = 0; j < f::product(kDims); ++j) { + for (int64_t j = 0; j < f::product(kDims); ++j) { ASSERT_NEAR(ct[j], result, 1e-5); } } diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index eda8226480a66ae1a631391e9335db04604039c5..9213cc7a85822e4c78ef72aec2bf86d2edac023a 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -36,7 +36,7 @@ class ReshapeOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); auto x_dims = ctx->GetInputDim("X"); // TODO(qiao) change batch_size - for (int i = 1; i < shape.size(); ++i) { + for (size_t i = 1; i < shape.size(); ++i) { PADDLE_ENFORCE(shape[i] > 0, "Each dimension of shape " "must be positiv except the first."); diff --git a/paddle/operators/save_load_op_test.cc b/paddle/operators/save_load_op_test.cc index fe2b15ec09c6d29ad5f78e5c36f534c6a88497e6..a57466a48d4d6016fe2618d19fdca4c4f667124a 100644 --- a/paddle/operators/save_load_op_test.cc +++ b/paddle/operators/save_load_op_test.cc @@ -34,7 +34,7 @@ TEST(SaveLoadOp, CPU) { tensor->set_lod(expect_lod); int* expect = tensor->mutable_data(place); - for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) { + for (int64_t i = 0; i < tensor->numel(); ++i) { expect[i] = static_cast(i); } paddle::framework::AttributeMap attrs; @@ -50,7 +50,7 @@ TEST(SaveLoadOp, CPU) { "load", {}, {{"Out", {"out_var"}}}, attrs); load_op->Run(scope, ctx); int* actual = target->data(); - for (size_t i = 0; i < paddle::framework::product(tensor->dims()); ++i) { + for (int64_t i = 0; i < tensor->numel(); ++i) { EXPECT_EQ(expect[i], actual[i]); } auto& actual_lod = target->lod(); @@ -60,4 +60,4 @@ TEST(SaveLoadOp, CPU) { EXPECT_EQ(expect_lod[i][j], actual_lod[i][j]); } } -} \ No newline at end of file +} diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 6d600c27271c660f0cf933e8bd05455df61740ec..29d19df10898634dd433abc1263fefe169de6f08 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -39,15 +39,14 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor), output of SequencePoolOp, which does not contain LoD " "infomation."); - AddAttr( - "strategy", - "(int, default AVERAGE) the pooling strategy of SequencePoolOp.") - .SetDefault(AVERAGE) - .InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST}); + AddAttr( + "pooltype", + "(int, default AVERAGE) the pooling pooltype of SequencePoolOp.") + .SetDefault("AVERAGE"); AddComment(R"DOC( SequencePoolOp pools features of all time-steps of each instance. 
- It supports six pooling strategy: + It supports six pooling types: - AVERAGE: Out[i] = average_{for each instance in i-th sequence}{X[i]} - SUM: Out[i] = sum_{for each instance in i-th sequence}{X[i]} - SQRT: Out[i] = sum_{for each instance in i-th sequence}{X[i]} @@ -63,7 +62,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. Thus, Out is a [3,1,1] Tensor without LoD infomation. - And for different strategy, the value of Out is as follows: + And for each pooltype, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 - SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1 diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 07bf61df45bf51c8648180ffc9eb97306865fab6..e0e0493fe0ef7e1963ce5c2e3f37c164a605809b 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -29,22 +29,13 @@ template using EigenMatrix = framework::EigenMatrix; -enum SeqPoolType { - AVERAGE = 0, - SUM = 1, - SQRT = 2, // square_root_n - MAX = 3, - LAST = 4, - FIRST = 5 -}; - template class SequencePoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - int strategy = context.Attr("strategy"); + std::string pooltype = context.Attr("pooltype"); auto dims = in->dims(); auto lod = in->lod(); @@ -71,28 +62,21 @@ class SequencePoolKernel : public framework::OpKernel { auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w})); auto out_e = EigenVector::Flatten(out_t); - switch (strategy) { - case AVERAGE: - out_e.device(place) = in_e.mean(Eigen::array({{0}})); - break; - case SUM: - out_e.device(place) = in_e.sum(Eigen::array({{0}})); - break; - case SQRT: - out_e.device(place) = in_e.sum(Eigen::array({{0}})) / - std::sqrt(static_cast(h)); - break; - case MAX: - out_e.device(place) = in_e.maximum(Eigen::array({{0}})); - break; - case LAST: - out_e.device(place) = in_e.chip(h - 1, 0); - break; - case FIRST: - out_e.device(place) = in_e.chip(0, 0); - break; - default: - PADDLE_THROW("unsupported pooling strategy"); + if (pooltype == "AVERAGE") { + out_e.device(place) = in_e.mean(Eigen::array({{0}})); + } else if (pooltype == "SUM") { + out_e.device(place) = in_e.sum(Eigen::array({{0}})); + } else if (pooltype == "SQRT") { + out_e.device(place) = in_e.sum(Eigen::array({{0}})) / + std::sqrt(static_cast(h)); + } else if (pooltype == "MAX") { + out_e.device(place) = in_e.maximum(Eigen::array({{0}})); + } else if (pooltype == "LAST") { + out_e.device(place) = in_e.chip(h - 1, 0); + } else if (pooltype == "FIRST") { + out_e.device(place) = in_e.chip(0, 0); + } else { + PADDLE_THROW("unsupported pooling type"); } } } @@ -105,15 +89,15 @@ class SequencePoolGradKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* in_g = context.Output(framework::GradVarName("X")); auto* out_g = context.Input(framework::GradVarName("Out")); - int strategy = context.Attr("strategy"); + std::string pooltype = context.Attr("pooltype"); auto dims = in->dims(); auto lod = in->lod()[0]; int64_t w = in->numel() / dims[0]; in_g->mutable_data(context.GetPlace()); - if (strategy == LAST || strategy == FIRST) { - // set X@Grad be zero at first when strategy is LAST/FIRST + if (pooltype == "LAST" || pooltype == "FIRST") { + // set X@Grad to zero at first when pooltype is LAST/FIRST math::SetConstant functor;
functor(context.device_context(), in_g, 0); } @@ -127,41 +111,33 @@ class SequencePoolGradKernel : public framework::OpKernel { auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); Eigen::DSizes bcast(h, 1); - switch (strategy) { - case AVERAGE: - in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); - break; - case SUM: - in_g_e.device(place) = (out_g_e).broadcast(bcast); - break; - case SQRT: - in_g_e.device(place) = - (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); - break; - case MAX: { - auto in_t = - in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); - Eigen::Map> - in_t_map(in_t.data(), h, w); - int row_id; - Eigen::array extents{{1, 1}}; - for (int col_id = 0; col_id < w; col_id++) { - in_t_map.col(col_id).maxCoeff(&row_id); - Eigen::array in_offsets{{row_id, col_id}}; - Eigen::array out_offsets{{0, col_id}}; - in_g_e.slice(in_offsets, extents).device(place) = - out_g_e.slice(out_offsets, extents); - } - break; + if (pooltype == "AVERAGE") { + in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); + } else if (pooltype == "SUM") { + in_g_e.device(place) = (out_g_e).broadcast(bcast); + } else if (pooltype == "SQRT") { + in_g_e.device(place) = + (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); + } else if (pooltype == "MAX") { + auto in_t = + in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); + Eigen::Map> + in_t_map(in_t.data(), h, w); + int row_id; + Eigen::array extents{{1, 1}}; + for (int col_id = 0; col_id < w; col_id++) { + in_t_map.col(col_id).maxCoeff(&row_id); + Eigen::array in_offsets{{row_id, col_id}}; + Eigen::array out_offsets{{0, col_id}}; + in_g_e.slice(in_offsets, extents).device(place) = + out_g_e.slice(out_offsets, extents); } - case LAST: - in_g_e.chip(h - 1, 0).device(place) = out_g_e; - break; - case FIRST: - in_g_e.chip(0, 0).device(place) = out_g_e; - break; - default: - PADDLE_THROW("unsupported pooling strategy"); + } else if (pooltype == "LAST") { + in_g_e.chip(h - 1, 0).device(place) = out_g_e; + } else if (pooltype == "FIRST") { + in_g_e.chip(0, 0).device(place) = out_g_e; + } else { + PADDLE_THROW("unsupported pooling type"); } } } diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 942fbb42df8bb90b86bd097832a15b320a857750..50497da1b70d39d2638240dd91035c9181124af9 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -32,9 +32,9 @@ class SoftmaxWithCrossEntropyOpMaker AddInput("Label", "(Tensor, default: Tensor), The ground truth which is a 2-D " "tensor. " - "If softLable is set to 0, Label is a Tensor with shape [N x " - "1]. " - "If softLable is set to 1, Label is a Tensor " + "If softLabel is set to false, Label is a Tensor with shape " + "[N x 1]. " + "If softLabel is set to true, Label is a Tensor " "with shape [N x K]."); AddOutput( "Softmax", @@ -60,19 +60,23 @@ Because this operators performs a softmax on logits internally, it expects unscaled logits. Please do not call this op with the output of softmax operator, which will produce incorrect results. -This operators expects mutually exclusive hard labels, each sample in a batch -is in exactly one class with probabilities 1. Each sample in the batch with one -and only one label. +When the attribute softLabel is set false, this operator expects mutually +exclusive hard labels; each sample in a batch is in exactly one class with +probability 1, i.e. each sample in the batch has one and only one label.
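+When the attribute softLabel is set true, Label should instead hold a soft label, +i.e. a probability distribution over all the classes for each sample, stored in a +Tensor with shape [N x K].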
Equation: 1) hard label (one-hot label) -Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K +Loss_j = \f$ -\text{Logit}_{Label_j} + +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), +j = 1, ..., K \f$ 2) soft label (a distribution over all classes) -Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K +Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i - +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), +j = 1,...,K \f$ )DOC"); } diff --git a/paddle/scripts/docker/build_android.sh b/paddle/scripts/docker/build_android.sh index 11612ad4bed0afa8496087605afaefbd0420d5ce..6ef45d33d8c9e32e564555854c10a6fe15e4fd9f 100644 --- a/paddle/scripts/docker/build_android.sh +++ b/paddle/scripts/docker/build_android.sh @@ -4,6 +4,10 @@ set -xe if [ $ANDROID_ABI == "arm64-v8a" ]; then ANDROID_ARCH=arm64 + if [ $ANDROID_API -lt 21 ]; then + echo "Warning: arm64-v8a requires ANDROID_API >= 21." + ANDROID_API=21 + fi else # armeabi, armeabi-v7a ANDROID_ARCH=arm fi diff --git a/paddle/trainer/MergeModel.cpp b/paddle/trainer/MergeModel.cpp index a70673ffec8812d86b9a0c13f15ef0b378dcf3ce..f3cfd9f97fea837e8f666f2eabee5a75659a4e42 100644 --- a/paddle/trainer/MergeModel.cpp +++ b/paddle/trainer/MergeModel.cpp @@ -27,6 +27,13 @@ using namespace paddle; // NOLINT using namespace std; // NOLINT int main(int argc, char** argv) { + if (FLAGS_model_dir.empty() || FLAGS_config_file.empty() || + FLAGS_model_file.empty()) { + LOG(INFO) << "Usage: ./paddle_merge_model --model_dir=pass-00000 " + "--config_file=config.py --model_file=out.paddle"; + return 0; + } + initMain(argc, argv); initPython(argc, argv); diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index dab72f0195fd9f7ec8219be7867e8ddde9036db9..86a2c7bf08b09638fd065f96300fc4f7ffc332d5 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -351,32 +351,21 @@ def conv2d(input, return helper.append_activation(pre_act) -def sequence_pool(input, pool_type, program=None, init_program=None): - # FIXME(dzh) : want to unify the argument of python layer - # function. So we ignore some unecessary attributes - - ENUM_POOL_TYPE = dict({ - "AVERAGE": 0, - "SUM": 1, - "SQRT": 2, - "MAX": 3, - "LAST": 4, - "FIRST": 5 - }) +def sequence_pool(input, pool_type, **kwargs): + ENUM_POOL_TYPE = set(["AVERAGE", "SUM", "SQRT", "MAX", "LAST", "FIRST"]) if pool_type.upper() not in ENUM_POOL_TYPE: raise ValueError("Unknown pool_type: '%s'. 
It can only be %s.", - str(pool_type), " ".join(ENUM_POOL_TYPE.keys())) + str(pool_type), " ".join(ENUM_POOL_TYPE)) - helper = LayerHelper('sequence_pool', **locals()) + helper = LayerHelper('sequence_pool', **kwargs) dtype = helper.input_dtype() pool_out = helper.create_tmp_variable(dtype) - # FIXME(dzh): strategy helper.append_op( type="sequence_pool", inputs={"X": [input]}, outputs={"Out": [pool_out]}, - attrs={"strategy": ENUM_POOL_TYPE[pool_type.upper()]}) + attrs={"pooltype": pool_type.upper()}) return pool_out diff --git a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6f06a66c825b37ee91214efc0a29a58f0b9057f9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py @@ -0,0 +1,142 @@ +import unittest +import random +import numpy as np + +from op_test import OpTest + + +class LinearChainCrfForward(object): + def __init__(self, seq_start_positions, emission_weights, emission_row_max, + emission_exps, transition_weights, transition_exps, labels): + self.tag_num = emission_weights.shape[1] + self.seq_num = len(seq_start_positions) - 1 + + self.seq_start_positions = seq_start_positions + self.labels = labels + self.x = emission_weights + + self.x_row_max = emission_row_max + self.x_exps = emission_exps + + # unnormalized logits of the transition weights for the start mark. + self.a = transition_weights[0, :] + self.a_exps = transition_exps[0, :] + # unnormalized logits of the transition weights for the end mark. + self.b = transition_weights[1, :] + self.b_exps = transition_exps[1, :] + # unnormalized logits of the transition weights for all the other tags. + self.w = transition_weights[2:, :] + self.w_exps = transition_exps[2:, :] + + # The output of linear chain crf operator. + # alpha is a memo table in dynamic programming to calculate + # the normalization factor. + self.alpha = np.zeros( + (seq_start_positions[-1], self.tag_num), dtype="float64") + self.log_likelihood = np.zeros((self.seq_num, 1)) + + def _l1_norm(self, x): + s = np.sum(x) + x /= s + return s + + def _forward_a_sequence(self, x, x_row_max, x_exps, label, alpha): + seq_len = x_row_max.shape[0] + log_likelihood = 0. + + for i in range(self.tag_num): + alpha[0, i] = self.a_exps[i] * x_exps[0, i] + log_likelihood = -x_row_max[0] - np.log(self._l1_norm(alpha[0, :])) + + # calculate the unnormalized logits of the normalization factor. + for k in range(1, seq_len): + for i in range(self.tag_num): + s = 0. + for j in range(self.tag_num): + s += alpha[k - 1, j] * self.w_exps[j, i] + alpha[k, i] = x_exps[k, i] * s + log_likelihood -= x_row_max[k] + np.log(self._l1_norm(alpha[k, :])) + s = 0. + for i in range(self.tag_num): + s += alpha[-1, i] * self.b_exps[i] + log_likelihood -= np.log(s) + + # calculate the numerator part. 
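+ # i.e. the unnormalized score of the ground-truth tag path: the start + # transition, the per-step emissions, the pairwise transitions, and the + # end transition.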
+ log_likelihood += ( + self.a[label[0]] + x[0, label[0]] + self.b[label[-1]]) + + for k in range(1, seq_len): + log_likelihood += (x[k, label[k]] + self.w[label[k - 1], label[k]]) + return -log_likelihood + + def crf_forward_compute(self): + for i in range(self.seq_num): + start = self.seq_start_positions[i] + end = self.seq_start_positions[i + 1] + + self.log_likelihood[i] = self._forward_a_sequence( + self.x[start:end, :], self.x_row_max[start:end, :], + self.x_exps[start:end, :], self.labels[start:end, :], + self.alpha[start:end, :]) + return self.alpha, self.log_likelihood + + +class TestLinearChainCrfOp(OpTest): + def set_test_data(self): + # TODO(caoying) Fix the unittest by: add the boundary cases when + # sequence lengths are 1, 2, and 3. + + SEQ_NUM = 3 + TAG_NUM = 17 + MAX_SEQ_LEN = 5 + + # the linear_chain_crf operator only supports sequence (LoD level = 1) + lod = [[0]] + for i in range(SEQ_NUM): + lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN)) + emission = np.random.uniform(-1, 1, + [lod[-1][-1], TAG_NUM]).astype("float64") + emission_row_max = np.amax(emission, axis=1, keepdims=True) + emission_exps = np.exp(emission - emission_row_max) + + transition = np.random.uniform(-0.5, 0.5, + [TAG_NUM + 2, TAG_NUM]).astype("float64") + transition_exps = np.exp(transition) + + labels = np.random.randint( + low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32") + + self.inputs = { + "Emission": (emission, lod), + "Transition": transition, + "Label": (labels, lod) + } + crf = LinearChainCrfForward(lod[0], emission, emission_row_max, + emission_exps, transition, transition_exps, + labels) + alpha, log_likelihood = crf.crf_forward_compute() + + self.outputs = { + "Alpha": alpha, + "EmissionExps": emission_exps, + "TransitionExps": transition_exps, + "LogLikelihood": log_likelihood + } + + def setUp(self): + self.op_type = "linear_chain_crf" + self.set_test_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["Emission", "Transition"], "LogLikelihood") + + def test_check_grad_ignore_transition(self): + self.check_grad( + ["Emission"], "LogLikelihood", no_grad_set=set("Transition")) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 93a4e450e916716e27573d192bace73f271733de..ff75160083f2936dd653a8396254bf16d1752ffa 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -52,7 +52,7 @@ def lstm( g = np.dot(h_pre, w_h) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) - c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) + c, g_i, g_f, g_o = np.split(g, 4, axis=1) if w_c is None: g_i = act_gate(g_i) # 1 x D g_f = act_gate(g_f) # 1 x D @@ -60,7 +60,7 @@ def lstm( w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) g_i = act_gate(g_i + w_ic * c_pre) # 1 x D g_f = act_gate(g_f + w_fc * c_pre) # 1 x D - c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D + c = g_f * c_pre + g_i * act_cand(c) # 1 x D if w_c is None: g_o = act_gate(g_o) # 1 x D @@ -68,8 +68,7 @@ def lstm( _, _, w_oc = np.split(w_c, 3, axis=1) g_o = act_gate(g_o + w_oc * c) # 1 x D h = g_o * act_cell(c) - bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1) - return h, c, bg + return h, c def _reverse(x, lod): y = np.zeros_like(x) @@ -82,7 +81,6 @@ def lstm( batch_size = len(offset) - 1 hidden = [] cell = [] - gate = [] input = _reverse(input, offset) if is_reverse else input if w_b is 
not None: input = input + np.tile(w_b, (offset[-1], 1)) @@ -94,96 +92,109 @@ def lstm( c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, - act_cell, act_cand) + h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, + act_cell, act_cand) hidden.append(h_pre.flatten()) cell.append(c_pre.flatten()) - gate.append(g_pre.flatten()) - hidden = np.array(hidden).astype("float64") - cell = np.array(cell).astype("float64") - gate = np.array(gate).astype("float64") + hidden = np.array(hidden).astype('float64') + cell = np.array(cell).astype('float64') hidden = _reverse(hidden, offset) if is_reverse else hidden cell = _reverse(cell, offset) if is_reverse else cell - assert gate.shape == input.shape assert hidden.shape == (input.shape[0], input.shape[1] / 4) assert cell.shape == (input.shape[0], input.shape[1] / 4) - return hidden, cell, gate + return hidden, cell class TestLstmOp(OpTest): - def set_data(self): - self.lod = [[0, 2, 6, 9]] - self.D = 64 - self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 - self.act_gate = "sigmoid" - self.act_cell = "tanh" - self.act_cand = "tanh" + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.has_initial_state = True self.is_reverse = False def setUp(self): - self.set_data() - self.op_type = "lstm" + self.set_argument() + self.op_type = 'lstm' T = self.lod[0][-1] N = len(self.lod[0]) - 1 - x = np.random.normal(size=(T, 4 * self.D)).astype("float64") - h0 = np.zeros((N, self.D)).astype("float64") - c0 = np.zeros((N, self.D)).astype("float64") - w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64") - b = np.random.normal(size=(1, 7 * self.D)).astype("float64") + x = np.random.normal(size=(T, 4 * self.D)).astype('float64') + h0 = np.zeros((N, self.D)).astype('float64') + c0 = np.zeros((N, self.D)).astype('float64') + w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') + b = np.random.normal(size=(1, 7 * self.D)).astype('float64') w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] - h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand]) - - g_sort = np.zeros_like(x) - for i, j in enumerate(self.sort_idx): - g_sort[i, :] = g[j, :] - - self.inputs = { - 'Input': (x, self.lod), - 'H0': h0, - 'C0': c0, - 'Weight': w, - 'Bias': b - } + h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, + ACTVATION[self.act_gate], ACTVATION[self.act_cell], + ACTVATION[self.act_cand]) + + self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b} + if self.has_initial_state: + self.inputs['H0'] = h0 + self.inputs['C0'] = c0 + self.outputs = { 'Hidden': (h, self.lod), 'Cell': (c, self.lod), - 'BatchGate': g_sort } self.attrs = { 'usePeepholes': True, 'isReverse': self.is_reverse, - 'gateActivation': 'sigmoid', - 'cellActivation': 'tanh', - 'candidateActivation': 'tanh' + 'gateActivation': self.act_gate, + 'cellActivation': self.act_cell, + 'candidateActivation': self.act_cand } def test_check_output(self): - self.check_output() + self.check_output(atol=1e-8) + + #TODO(qingqing) add more unit testing case + def test_check_grad(self): + # TODO(qingqing) remove folowing lines after the check_grad is refined. 
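+ # BatchGate and BatchCellPreAct are intermediate outputs of the op; the zero + # placeholders below are only a stopgap for the current check_grad.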
+ N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) + + +class TestLstmOpHasNoInitial(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = False + self.is_reverse = True class TestLstmOpRerverse(TestLstmOp): - def set_data(self): - self.lod = [[0, 2, 6, 9]] - self.D = 64 - self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 - self.act_gate = "sigmoid" - self.act_cell = "tanh" - self.act_cand = "tanh" + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.has_initial_state = True self.is_reverse = True -if __name__ == "__main__": +if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py index 56602c57e6b63b71d6b089e774a876ad6164040e..efc4920124afb539017a3b3f211c7320da68ffef 100644 --- a/python/paddle/v2/framework/tests/test_seq_pool.py +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -3,15 +3,6 @@ import numpy as np from op_test import OpTest -class SeqPoolType(OpTest): - AVERAGE = 0 - SUM = 1 - SQRT = 2 - MAX = 3 - LAST = 4 - FIRST = 5 - - class TestSeqAvgPool(OpTest): def set_data(self): self.op_type = 'sequence_pool' @@ -25,7 +16,7 @@ class TestSeqAvgPool(OpTest): return x, lod, out def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.AVERAGE} + self.attrs = {'pooltype': "AVERAGE"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x.mean(axis=0) @@ -54,7 +45,7 @@ class TestSeqAvgPool2D(TestSeqAvgPool): return x, lod, out def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.AVERAGE} + self.attrs = {'pooltype': "AVERAGE"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) @@ -62,7 +53,7 @@ class TestSeqAvgPool2D(TestSeqAvgPool): class TestSeqSumPool(TestSeqAvgPool): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.SUM} + self.attrs = {'pooltype': "SUM"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x.sum(axis=0) @@ -70,7 +61,7 @@ class TestSeqSumPool(TestSeqAvgPool): class TestSeqSumPool2D(TestSeqAvgPool2D): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.SUM} + self.attrs = {'pooltype': "SUM"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.sum(axis=0), (3, 17)) @@ -78,7 +69,7 @@ class TestSeqSumPool2D(TestSeqAvgPool2D): class TestSeqSqrtPool(TestSeqAvgPool): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.SQRT} + self.attrs = {'pooltype': "SQRT"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] len = lod[0][i + 1] - lod[0][i] @@ -87,7 +78,7 @@ class TestSeqSqrtPool(TestSeqAvgPool): class TestSeqSqrtPool2D(TestSeqAvgPool2D): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.SQRT} + self.attrs = {'pooltype': "SQRT"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) len = lod[0][i + 1] - lod[0][i] @@ -99,7 +90,7 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D): class TestSeqMaxPool(TestSeqAvgPool): 
def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.MAX} + self.attrs = {'pooltype': "MAX"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = np.amax(sub_x, axis=0) @@ -111,7 +102,7 @@ class TestSeqMaxPool(TestSeqAvgPool): class TestSeqMaxPool2D(TestSeqAvgPool2D): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.MAX} + self.attrs = {'pooltype': "MAX"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 17)) @@ -123,7 +114,7 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D): class TestSeqLastPool(TestSeqAvgPool): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.LAST} + self.attrs = {'pooltype': "LAST"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x[-1, :] @@ -131,7 +122,7 @@ class TestSeqLastPool(TestSeqAvgPool): class TestSeqLastPool2D(TestSeqAvgPool2D): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.LAST} + self.attrs = {'pooltype': "LAST"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[-1, :], (3, 17)) @@ -139,7 +130,7 @@ class TestSeqLastPool2D(TestSeqAvgPool2D): class TestSeqFirstPool(TestSeqAvgPool): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.FIRST} + self.attrs = {'pooltype': "FIRST"} for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x[0, :] @@ -147,7 +138,7 @@ class TestSeqFirstPool(TestSeqAvgPool): class TestSeqFirstPool2D(TestSeqAvgPool2D): def compute(self, x, lod, out): - self.attrs = {'strategy': SeqPoolType.FIRST} + self.attrs = {'pooltype': "FIRST"} for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[0, :], (3, 17))