diff --git a/doc/howto/dev/write_docs_cn.rst b/doc/howto/dev/write_docs_cn.rst index 36e5d420c986fc8d88eefee4aa221dba0a0480f2..731a63f945c29ba78538b3d71289b234e569354d 100644 --- a/doc/howto/dev/write_docs_cn.rst +++ b/doc/howto/dev/write_docs_cn.rst @@ -5,15 +5,13 @@ PaddlePaddle的文档包括英文文档 ``doc`` 和中文文档 ``doc_cn`` 两个部分。文档都是通过 `cmake`_ 驱动 `sphinx`_ 编译生成,生成后的文档分别存储在编译目录的 ``doc`` 和 ``doc_cn`` 两个子目录下。 -如何构建PaddlePaddle的文档 -========================== +如何构建文档 +============ -PaddlePaddle的文档构建有直接构建和基于Docker构建两种方式,我们提供了一个构建脚本build_docs.sh来进行构建。 -PaddlePaddle文档需要准备的环境相对较复杂,所以我们推荐使用基于Docker来构建PaddlePaddle的文档。 +PaddlePaddle的文档构建有两种方式。 - -使用Docker构建PaddlePaddle的文档 --------------------------------- +使用Docker构建 +-------------- 使用Docker构建PaddlePaddle的文档,需要在系统里先安装好Docker工具包。Docker安装请参考 `Docker的官网 `_ 。安装好Docker之后可以使用源码目录下的脚本构建文档,即 @@ -21,58 +19,46 @@ PaddlePaddle文档需要准备的环境相对较复杂,所以我们推荐使 cd TO_YOUR_PADDLE_CLONE_PATH cd paddle/scripts/tools/build_docs - bash build_docs.sh with_docker - -编译完成后,会在当前目录生成两个子目录\: - -* doc 英文文档目录 -* doc_cn 中文文档目录 + sh build_docs.sh +编译完成之后,会在当前目录生成两个子目录\: doc(英文文档目录)和 doc_cn(中文文档目录)。 打开浏览器访问对应目录下的index.html即可访问本地文档。 - - -直接构建PaddlePaddle的文档 --------------------------- - -因为PaddlePaddle的v2 api文档生成过程依赖于py_paddle Python包,用户需要首先确认py_paddle包已经安装。 - -.. code-block:: bash - - python -c "import py_paddle" - -如果提示错误,那么用户需要在本地编译安装PaddlePaddle,请参考 `源码编译文档 `_ 。 -注意,用户在首次编译安装PaddlePaddle时,请将WITH_DOC选项关闭。在编译安装正确之后,请再次确认py_paddle包已经安装,即可进行下一步操作。 +直接构建 +-------- 如果提示正确,可以执行以下命令编译生成文档,即 .. code-block:: bash cd TO_YOUR_PADDLE_CLONE_PATH - cd paddle/scripts/tools/build_docs - bash build_docs.sh local - -编译完成之后,会在当前目录生成两个子目录\: - -* doc 英文文档目录 -* doc_cn 中文文档目录 + mkdir -p build + cd build + cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_DOC=ON + make gen_proto_py + make paddle_docs paddle_docs_cn +编译完成之后,会在当前目录生成两个子目录\: doc(英文文档目录)和 doc_cn(中文文档目录)。 打开浏览器访问对应目录下的index.html即可访问本地文档。 -如何书写PaddlePaddle的文档 -========================== +如何书写文档 +============ PaddlePaddle文档使用 `sphinx`_ 自动生成,用户可以参考sphinx教程进行书写。 -如何更新www.paddlepaddle.org文档 -================================ +如何更新文档主题 +================ + +PaddlePaddle文档主题在 `TO_YOUR_PADDLE_CLONE_PATH/doc_theme` 文件夹下,包含所有和前端网页设计相关的文件。 -开发者给PaddlePaddle代码增加的注释以PR的形式提交到github中,提交方式可参见 `贡献文档 `_ 。 +如何更新doc.paddlepaddle.org +============================ + +更新的文档以PR的形式提交到github中,提交方式参见 `贡献文档 `_ 。 目前PaddlePaddle的develop分支的文档是自动触发更新的,用户可以分别查看最新的 `中文文档 `_ 和 `英文文档 `_ 。 - .. _cmake: https://cmake.org/ .. _sphinx: http://www.sphinx-doc.org/en/1.4.8/ diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c0838d9b759110fd706577386d2c81bda6876223..3371962c635c3731f00a6af2a6e287ece33397cd 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -9,6 +9,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor) +nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_test(variable_test SRCS variable_test.cc) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 154068fef69bc96edbd85b731fe8091b3b1ff823..568f4e89819c8345d8908634f6fa56f09483a763 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -18,8 +18,10 @@ #ifndef PADDLE_ONLY_CPU #include #include +#include #endif +#include #include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" #include "paddle/platform/enforce.h" @@ -32,7 +34,8 @@ template using Vector = std::vector; #else template -using Vector = thrust::host_vector; +using Vector = thrust::host_vector< + T, thrust::system::cuda::experimental::pinned_allocator>; #endif using LoD = std::vector>; diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..1079a36a2e7b24f6f8a5bcbb296355567305a765 --- /dev/null +++ b/paddle/framework/lod_tensor_test.cu @@ -0,0 +1,52 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include +#include "paddle/framework/lod_tensor.h" +#include "paddle/platform/assert.h" + +#include + +__global__ void test(size_t* a, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) { + a[i] *= 2; + } +} + +TEST(LoDTensor, LoDInGPU) { + paddle::framework::Tensor tensor; + paddle::framework::LoDTensor lod_tensor; + paddle::platform::GPUPlace place(0); + + paddle::framework::LoD src_lod; + src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14}); + + tensor.Resize({14, 16}); + tensor.mutable_data(place); + + lod_tensor.set_lod(src_lod); + lod_tensor.set_tensor(&tensor); + CHECK_EQ(lod_tensor.lod_element(0, 2), 4); + CHECK_EQ(lod_tensor.lod_element(0, 4), 8); + + auto lod = lod_tensor.lod(); + + test<<<1, 8>>>(lod[0].data(), lod[0].size()); + cudaDeviceSynchronize(); + + for (size_t i = 0; i < src_lod[0].size(); ++i) { + CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); + } +} diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index edef36194aabdb9c122ec3423deb036169a34d7c..4002a3d0747a86ab7b495ffe52247521831b71b8 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -49,6 +49,12 @@ struct LayerState { }; typedef std::shared_ptr LayerStatePtr; +/// Paddle device ID, MKLDNN is -2, CPU is -1 +enum PADDLE_DEVICE_ID { + MKLDNN_DEVICE = -2, + CPU_DEVICE = -1, +}; + /** * @brief Base class for layer. * Define necessary variables and functions for every layer. @@ -59,11 +65,6 @@ protected: LayerConfig config_; /// whether to use GPU bool useGpu_; - /// Paddle device ID, MKLDNN is -2, CPU is -1 - enum PADDLE_DEVICE_ID { - MKLDNN_DEVICE = -2, - CPU_DEVICE = -1, - }; /// Device Id. MKLDNN is -2, CPU is -1, and GPU is 0, 1, 2 ... int deviceId_; /// Input layers diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index 53433cef35a377a73f87b041fdcfadd848dd2ec9..f70343251ad4fbb99f9614618f6d1bff1174f15e 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "MKLDNNFcLayer.h" #include "paddle/utils/Logging.h" -#include "paddle/utils/Stat.h" using namespace mkldnn; // NOLINT typedef memory::format format; @@ -40,6 +39,8 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, oc_ = getSize(); oh_ = 1; ow_ = 1; + ih_ = 1; + iw_ = 1; // input size can not change in FC iLayerSize_ = inputLayers_[0]->getSize(); @@ -77,67 +78,53 @@ void MKLDNNFcLayer::convertWeightsToPaddle() { wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); } -void MKLDNNFcLayer::reshape() { - const Argument& input = getInput(0, getPrev(0)->getDeviceId()); - int batchSize = input.getBatchSize(); - if (bs_ == batchSize) { - return; - } - bs_ = batchSize; - ih_ = input.getFrameHeight(); - iw_ = input.getFrameWidth(); - if (ih_ == 0) { - ih_ = 1; - } - if (iw_ == 0) { - iw_ = 1; - } - CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize()); - ic_ = iLayerSize_ / (ih_ * iw_); - CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible"; - CHECK_EQ(size_t(oc_), getSize()); - printSizeInfo(); +void MKLDNNFcLayer::reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { + reshapeInput(bs, ih, iw); - // reset output - output_.setFrameHeight(oh_); - output_.setFrameWidth(ow_); - resetOutput(bs_, oc_); + CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize()); + ic = iLayerSize_ / (ih * iw); + CHECK_EQ(size_t(ic * ih * iw), iLayerSize_) << "not divisible"; + CHECK_EQ(size_t(oc), getSize()); - // reset mkldnn forward - resetFwd(); - needResetBwd_ = true; + reshapeOutput(oh, ow); + resizeOutput(bs, oc); - convertWeightsFromPaddle(); + printSizeInfo(); } -void MKLDNNFcLayer::resetFwd() { +void MKLDNNFcLayer::resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); bool hasBias = biases_ && biases_->getW(); - const MatrixPtr& wgt = weight_->getW(); - const MatrixPtr& bias = hasBias ? biases_->getW() : nullptr; - const MatrixPtr& out = output_.value; + const MatrixPtr& wgtVal = weight_->getW(); + const MatrixPtr& biasVal = hasBias ? biases_->getW() : nullptr; + const MatrixPtr& outVal = output_.value; if (inputIsOnlyMKLDNN()) { - const MatrixPtr& in = getInputValue(0); - inVal_ = std::dynamic_pointer_cast(in); - CHECK(inVal_) << "Input should be MKLDNNMatrix"; + const MatrixPtr& inVal = getInputValue(0); + in = std::dynamic_pointer_cast(inVal); + CHECK(in) << "Input should be MKLDNNMatrix"; } else { CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; - const MatrixPtr& in = getInputValue(0, CPU_DEVICE); - inVal_ = MKLDNNMatrix::create( - in, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); - } - inVal_->downSpatial(); - wgtVal_ = MKLDNNMatrix::create( - wgt, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); - wgtVal_->downSpatial(); - biasVal_ = - hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr; - outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_); + const MatrixPtr& inVal = getInputValue(0, CPU_DEVICE); + in = MKLDNNMatrix::create( + inVal, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); + } + in->downSpatial(); + wgt = MKLDNNMatrix::create( + wgtVal, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); + wgt->downSpatial(); + bias = hasBias ? MKLDNNMatrix::create(biasVal, {oc_}, format::x, engine_) + : nullptr; + out = MKLDNNMatrix::create(outVal, {bs_, oc_}, format::nc, engine_); // change original output value to mkldnn output value - output_.value = std::dynamic_pointer_cast(outVal_); + output_.value = std::dynamic_pointer_cast(out); if (!outputIsOnlyMKLDNN()) { - copyOutputInfoToOtherDevice(); // fc cpu output value do not need create convert // just share point getOutput(CPU_DEVICE).value->setData(output_.value->getData()); @@ -146,27 +133,31 @@ void MKLDNNFcLayer::resetFwd() { // create forward handle prop_kind pk = prop_kind::forward; fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk, - inVal_->getMemoryDesc(), - wgtVal_->getMemoryDesc(), - biasVal_->getMemoryDesc(), - outVal_->getMemoryDesc()) + in->getMemoryDesc(), + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) : fc_fwd::desc(pk, - inVal_->getMemoryDesc(), - wgtVal_->getMemoryDesc(), - outVal_->getMemoryDesc()); + in->getMemoryDesc(), + wgt->getMemoryDesc(), + out->getMemoryDesc()); fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); if (hasBias) { - fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_)); + fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *bias, *out)); } else { - fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_)); + fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *out)); } printValueFormatFlow(); - pipelineFwd_.clear(); - pipelineFwd_.push_back(*fwd_); + pipeline.push_back(*fwd_); } -void MKLDNNFcLayer::resetBwd() { +void MKLDNNFcLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); if (!needResetBwd_) { return; } @@ -175,8 +166,8 @@ void MKLDNNFcLayer::resetBwd() { /// backward weight CHECK(inVal_) << "Should have input value"; - const MatrixPtr& wgt = weight_->getWGrad(); - const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr; + const MatrixPtr& wgtGrad = weight_->getWGrad(); + const MatrixPtr& biasGrad = hasBias ? biases_->getWGrad() : nullptr; // TODO(TJ): merge outgrad int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; @@ -187,107 +178,66 @@ void MKLDNNFcLayer::resetBwd() { // for CPU device: // fc do not need to convert from cpu device since output is always nc format // only need create from cpu device - const MatrixPtr& out = getOutput(device).grad; - outGrad_ = MKLDNNMatrix::create(out, outVal_->getPrimitiveDesc()); - wgtGrad_ = MKLDNNMatrix::create(wgt, wgtVal_->getPrimitiveDesc()); - biasGrad_ = hasBias ? MKLDNNMatrix::create(bias, biasVal_->getPrimitiveDesc()) - : nullptr; + const MatrixPtr& outGrad = getOutput(device).grad; + out = MKLDNNMatrix::create(outGrad, outVal_->getPrimitiveDesc()); + wgt = MKLDNNMatrix::create(wgtGrad, wgtVal_->getPrimitiveDesc()); + bias = hasBias ? MKLDNNMatrix::create(biasGrad, biasVal_->getPrimitiveDesc()) + : nullptr; // create memory primitive desc fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, inVal_->getMemoryDesc(), - wgtGrad_->getMemoryDesc(), - outGrad_->getMemoryDesc()); + wgt->getMemoryDesc(), + out->getMemoryDesc()); fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); fc_bwdWgt::desc bwdWgtDesc = hasBias ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgtGrad_->getMemoryDesc(), - biasGrad_->getMemoryDesc(), - outGrad_->getMemoryDesc()) + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) : fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgtGrad_->getMemoryDesc(), - outGrad_->getMemoryDesc()); + wgt->getMemoryDesc(), + out->getMemoryDesc()); fc_bwdWgt::primitive_desc bwdWgtPD = fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); if (hasBias) { - bwdWgt_.reset( - new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_)); + bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt, *bias)); } else { - bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_)); + bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt)); } - pipelineBwd_.clear(); - pipelineBwd_.push_back(*bwdWgt_); + pipeline.push_back(*bwdWgt_); /// backward data - const MatrixPtr& in = inputLayers_[0]->getOutput().grad; - if (in == nullptr) { + const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad; + if (inGrad == nullptr) { return; } if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) { // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done } else { - inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc()); + in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); } - fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(inVal_->getMemoryDesc(), - wgtGrad_->getMemoryDesc(), - outGrad_->getMemoryDesc()); + fc_bwdData::desc bwdDataDesc = fc_bwdData::desc( + inVal_->getMemoryDesc(), wgt->getMemoryDesc(), out->getMemoryDesc()); fc_bwdData::primitive_desc bwdDataPD = fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); CHECK(wgtVal_) << "Should have weight memory"; - bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_)); + bwdData_.reset(new fc_bwdData(bwdDataPD, *out, *wgtVal_, *in)); printGradFormatFlow(); - pipelineBwd_.push_back(*bwdData_); + pipeline.push_back(*bwdData_); } void MKLDNNFcLayer::updateInputData() { - if (inputLayers_[0]->getType() != "data") { - return; - } - real* iData = getInputValue(0, CPU_DEVICE)->getData(); - inVal_->setData(iData); + inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); } -void MKLDNNFcLayer::forward(PassType passType) { - Layer::forward(passType); - reshape(); - - { - REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); - updateInputData(); - - // just submit forward pipeline - stream_->submit(pipelineFwd_); - } - - /* activation */ { - REGISTER_TIMER_INFO("FwActTimer", getName().c_str()); - forwardActivation(); - } -} - -void MKLDNNFcLayer::backward(const UpdateCallback& callback) { - /* Do derivation */ { - REGISTER_TIMER_INFO("BpActTimer", getName().c_str()); - backwardActivation(); - } - - { - REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); - resetBwd(); - - // just sumbmit backward pipeline - stream_->submit(pipelineBwd_); - } - - { - REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); - weight_->getParameterPtr()->incUpdate(callback); - if (biases_ && biases_->getWGrad()) { - biases_->getParameterPtr()->incUpdate(callback); - } +void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { + weight_->getParameterPtr()->incUpdate(callback); + if (biases_ && biases_->getWGrad()) { + biases_->getParameterPtr()->incUpdate(callback); } } } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index 4ad67a16e056a718c45a28babcf22a7cd571b15c..3119f863496df092da13c08bf733f13c42e53780 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -45,35 +45,28 @@ public: bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) override; - void convertWeightsFromPaddle() override; - - void convertWeightsToPaddle() override; + void reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; - void forward(PassType passType) override; + void resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; - void backward(const UpdateCallback& callback) override; + void resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; void updateInputData() override; -protected: - /** - * reshape the input image sizes - * and reset output buffer size - * and reset mkldnn forward - */ - void reshape(); - - /** - * reset the forward primitve and memory - * only would be called when input size changes - */ - void resetFwd(); - - /** - * reset the backward primitve and memory for mkldnn fc - * only would be called when needed - */ - void resetBwd(); + void updateWeights(const UpdateCallback& callback) override; + + void convertWeightsFromPaddle() override; + + void convertWeightsToPaddle() override; }; } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 543364edceff684bdcd002a8f4f10e7ce5e6953b..169679c8297542cac4a43f5a8e1af311ad9282df 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "MKLDNNBase.h" #include "mkldnn.hpp" #include "paddle/math/MKLDNNMatrix.h" +#include "paddle/utils/Stat.h" DECLARE_bool(use_mkldnn); @@ -33,6 +34,8 @@ typedef std::shared_ptr MKLDNNLayerPtr; */ class MKLDNNLayer : public Layer { protected: + // input value element count + size_t inputElemenCnt_; // batch size int bs_; // input image channel, height and width @@ -52,7 +55,7 @@ protected: std::vector pipelineFwd_; std::vector pipelineBwd_; - // MKLDNNMatrixPtr + // MKLDNNMatrixPtr with internal format MKLDNNMatrixPtr inVal_; MKLDNNMatrixPtr inGrad_; MKLDNNMatrixPtr outVal_; @@ -65,6 +68,7 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), + inputElemenCnt_(0), bs_(0), ic_(0), ih_(0), @@ -95,12 +99,104 @@ public: if (!Layer::init(layerMap, parameterMap)) { return false; } + checkCPUOutputsNumber(); stream_.reset(new MKLDNNStream()); engine_ = CPUEngine::Instance().getEngine(); return true; } + void forward(PassType passType) override { + passType_ = passType; + + { + REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); + CHECK(!inputLayers_.empty()); + copySeqInfoToOutputs(); + size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt(); + if (inputElemenCnt_ != elemenCnt) { + // reset when input total sizes changed, not only the batchsize + inputElemenCnt_ = elemenCnt; + reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); + resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_); + convertWeightsFromPaddle(); + needResetBwd_ = true; + } + + if (inputLayers_[0]->getType() == "data") { + updateInputData(); + } + + stream_->submit(pipelineFwd_); + } + + /* activation */ { + REGISTER_TIMER_INFO("FwActTimer", getName().c_str()); + forwardActivation(); + } + } + + void backward(const UpdateCallback& callback) override { + /* Do derivation */ { + REGISTER_TIMER_INFO("BpActTimer", getName().c_str()); + backwardActivation(); + } + + { + REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); + if (needResetBwd_) { + resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); + needResetBwd_ = false; + } + + stream_->submit(pipelineBwd_); + } + + { + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + updateWeights(callback); + } + } + + /** + * reshape the input image sizes + * and reset output image and buffer size + * output channel can not be changed + */ + virtual void reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) = 0; + + /** + * reset the mkldnn forward primitve and memory + * only would be called when input size changes + */ + virtual void resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) = 0; + + /** + * reset the mkldnn backward primitve and memory for mkldnn fc + * only would be called when needed + */ + virtual void resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) = 0; + + /** + * Update input value data when input layer is "data" type. + * Since the input value data address might be changed. + */ + virtual void updateInputData() {} + + /** + * Update weights and biases if necessary. + */ + virtual void updateWeights(const UpdateCallback& callback) {} + /** * convert weight from paddle format to mkldnn format * weight_ will be override @@ -114,10 +210,38 @@ public: virtual void convertWeightsToPaddle() {} /** - * Update input value data when input layer is "data" type. - * Since the input value data address might be changed. + * add this interface as public for unit test */ - virtual void updateInputData() {} + void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); } + +protected: + /** + * reshape the input image sizes and input batchsize + */ + virtual void reshapeInput(int& batchsize, int& height, int& width) { + const Argument& input = inputLayers_[0]->getOutput(); + batchsize = input.getBatchSize(); + int h = input.getFrameHeight(); + int w = input.getFrameWidth(); + if (h != 0) { + height = h; + } + if (w != 0) { + width = w; + } + } + + /** + * reshape output image sizes + */ + virtual void reshapeOutput(size_t height, size_t width) { + output_.setFrameHeight(height); + output_.setFrameWidth(width); + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + outputOtherDevice_[i].setFrameHeight(height); + outputOtherDevice_[i].setFrameWidth(width); + } + } /** * print info about sizes @@ -133,8 +257,8 @@ public: */ virtual void printValueFormatFlow() { if (inVal_ && outVal_) { - VLOG(MKLDNN_FMTS) << "value format flow --- " << inVal_->getFormat() - << " >>> " << outVal_->getFormat(); + VLOG(MKLDNN_FMTS) << inVal_->getFormat() << " >>> " + << outVal_->getFormat(); } } @@ -143,36 +267,12 @@ public: */ virtual void printGradFormatFlow() { if (inGrad_ && outGrad_) { - VLOG(MKLDNN_FMTS) << "grad format flow --- " << inGrad_->getFormat() - << " <<< " << outGrad_->getFormat(); + VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<< " + << outGrad_->getFormat(); } } protected: - /** - * copy image size and sequence info to other device - * @note: can not directly use Layer::copyOutputToOtherDevice since here only - * copy base info and do not copy data value - */ - void copyOutputInfoToOtherDevice() { - int cnt = 0; - for (size_t i = 0; i < outputOtherDevice_.size(); i++) { - outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight()); - outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth()); - outputOtherDevice_[i].sequenceStartPositions = - output_.sequenceStartPositions; - outputOtherDevice_[i].subSequenceStartPositions = - output_.subSequenceStartPositions; - outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims; - if (outputOtherDevice_[i].deviceId == CPU_DEVICE) { - ++cnt; - } - } - if (cnt > 1) { - LOG(WARNING) << "should not have more than one CPU devie"; - } - } - /** * If input only has MKLDNN device. * Otherwise, only support the previous layer using CPU device. @@ -205,6 +305,7 @@ protected: */ void setDevice(int id) { deviceId_ = id; } +private: /** * Set deviceId of the params used in this layer. */ @@ -228,6 +329,42 @@ protected: parameter->setDevice(id); } } + + /** + * Check the cpu device number of outputOtherDevice_. + * should have only one at most. + */ + void checkCPUOutputsNumber(int max = 1) { + int cnt = 0; + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + if (outputOtherDevice_[i].deviceId == CPU_DEVICE) { + ++cnt; + } + } + CHECK_LE(cnt, max) << "too much CPU devies"; + } + + /** + * copy SeqInfo from input layer to this output and other output devices. + * @note: do not use getInput(0) since it used this deviceId_, + * use "inputLayers_[0]->getOutput()" instead. + */ + void copySeqInfoToOutputs() { + if (inputLayers_.empty() || !needSequenceInfo_) { + return; + } + const Argument& input = inputLayers_[0]->getOutput(); + output_.sequenceStartPositions = input.sequenceStartPositions; + output_.subSequenceStartPositions = input.subSequenceStartPositions; + output_.cpuSequenceDims = input.cpuSequenceDims; + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + outputOtherDevice_[i].sequenceStartPositions = + output_.sequenceStartPositions; + outputOtherDevice_[i].subSequenceStartPositions = + output_.subSequenceStartPositions; + outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims; + } + } }; } // namespace paddle diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index de1635be2af37cd0ba49010199a417090865b0e4..2f48e5b2d3ffc9337ed1314f6db6549e56263fdd 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -63,8 +63,12 @@ void MKLDNNTester::reset(const TestConfig& dnn, initTestLayer( configs_[i], &(layerMaps_[i]), &(parameters_[i]), &(testLayers_[i])); } - dnnLayer_ = testLayers_[DNN]; refLayer_ = testLayers_[REF]; + dnnLayer_ = std::dynamic_pointer_cast(testLayers_[DNN]); + CHECK(dnnLayer_); + // for comparison with Paddle reference results, + // need manually add cpu device output for test + dnnLayer_->addOutputArgument(CPU_DEVICE); EXPECT_EQ(dataLayers_[DNN].size(), dataLayers_[REF].size()); EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size()); @@ -109,20 +113,22 @@ void MKLDNNTester::randomBotDatas() { void MKLDNNTester::randomTopDiffs() { refLayer_->getOutputGrad()->randomizeUniform(); - dnnLayer_->getOutputGrad()->copyFrom(*(refLayer_->getOutputGrad())); - VLOG(lvl_) << "Random dom Backward Input, TopDiff: "; + dnnLayer_->getOutput(CPU_DEVICE) + .grad->copyFrom(*(refLayer_->getOutputGrad())); + VLOG(lvl_) << "Random Backward Input, TopDiff: "; printMatrix(refLayer_->getOutputGrad()); } void MKLDNNTester::checkForward() { - printTopDatas(); - double delta = compareMatrix(testLayers_[DNN]->getOutputValue(), - testLayers_[REF]->getOutputValue()); VLOG(MKLDNN_ALL) << "Check Forward"; + printTopDatas(); + double delta = compareMatrix(dnnLayer_->getOutput(-1).value, + refLayer_->getOutputValue()); EXPECT_LE(fabs(delta), eps_); } void MKLDNNTester::checkBackwardData() { + VLOG(MKLDNN_ALL) << "Check Backward Data"; // TODO(TJ): uncomment me when batch norm ready // const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm"; for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) { @@ -144,14 +150,12 @@ void MKLDNNTester::checkBackwardData() { } void MKLDNNTester::checkBackwardWgts() { + VLOG(MKLDNN_ALL) << "Check Backward Weight"; CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); vector dnnWgts; // used to temply save mkldnn weights saveWgt(parameters_[DNN], dnnWgts); - const MKLDNNLayerPtr dnnlayer = - std::dynamic_pointer_cast(dnnLayer_); - CHECK(dnnlayer); - dnnlayer->convertWeightsToPaddle(); + dnnLayer_->convertWeightsToPaddle(); for (size_t i = 0; i < parameters_[DNN].size(); ++i) { const VectorPtr& dnn = parameters_[DNN][i]->getBuf(PARAMETER_VALUE); const VectorPtr& ref = parameters_[REF][i]->getBuf(PARAMETER_VALUE); @@ -189,38 +193,38 @@ void MKLDNNTester::restoreWgt(const vector& from, } // clear parameters grad -void MKLDNNTester::clearWgtDiffs() { +void MKLDNNTester::clearWgtDiffs(size_t id) { + CHECK_LE(id, parameters_.size()); for (size_t n = 0; n < parameters_.size(); ++n) { - for (size_t i = 0; i < parameters_[n].size(); ++i) { - const VectorPtr& grad = parameters_[n][i]->getBuf(PARAMETER_GRADIENT); - if (grad) { - grad->zeroMem(); + if (id == n || id == parameters_.size()) { + for (size_t i = 0; i < parameters_[n].size(); ++i) { + const VectorPtr& grad = parameters_[n][i]->getBuf(PARAMETER_GRADIENT); + if (grad) { + grad->zeroMem(); + } } } } } -void MKLDNNTester::clearBotDiffs() { - // dnn and ref +void MKLDNNTester::clearBotDiffs(size_t id) { + CHECK_LE(id, dataLayers_.size()); for (size_t n = 0; n < dataLayers_.size(); ++n) { - // all inputs layers - for (size_t i = 0; i < dataLayers_[n].size(); ++i) { - dataLayers_[n][i]->getOutputGrad()->zeroMem(); + if (id == n || id == dataLayers_.size()) { + // clear inputs layers of this specific layer + for (size_t i = 0; i < dataLayers_[n].size(); ++i) { + dataLayers_[n][i]->getOutputGrad()->zeroMem(); + } } } } -void MKLDNNTester::clearBotDiffs(int n) { - CHECK_LT(n, NUM); - // all inputs layers - for (size_t i = 0; i < dataLayers_[n].size(); ++i) { - dataLayers_[n][i]->getOutputGrad()->zeroMem(); - } -} - -void MKLDNNTester::clearTopDatas() { +void MKLDNNTester::clearTopDatas(size_t id) { + CHECK_LE(id, testLayers_.size()); for (size_t i = 0; i < testLayers_.size(); ++i) { - testLayers_[i]->getOutputValue()->zeroMem(); + if (id == i || id == testLayers_.size()) { + testLayers_[i]->getOutputValue()->zeroMem(); + } } } @@ -300,16 +304,24 @@ void MKLDNNTester::runOnce() { checkForward(); // test backward + // simple updater + UpdateCallback updateCallback = [](Parameter* para) { + auto& grad = para->getBuf(PARAMETER_GRADIENT); + auto& value = para->getBuf(PARAMETER_VALUE); + real lr = 1e-3; + value->add(*grad, lr); + }; randomTopDiffs(); - dnnLayer_->backward(nullptr); - refLayer_->backward(nullptr); + dnnLayer_->backward(updateCallback); + refLayer_->backward(updateCallback); checkBackwardData(); checkBackwardWgts(); // clear buffers // ref code will addto the diff, dnn code will writeto it - // and clearTopDatas() and clearWgtDiffs() should be coverd by test layers + // and clearTopDatas(REF) should be coverd by ref layers clearBotDiffs(REF); + clearWgtDiffs(REF); } void MKLDNNTester::run(const TestConfig& dnn, diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index e55e4493ffdfe45b8cfdee423febd1878b8b3d8a..5ac885638cde7693a0c847733e7a6149c1b7e6c2 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "LayerGradUtil.h" #include "paddle/gserver/layers/MKLDNNBase.h" +#include "paddle/gserver/layers/MKLDNNLayer.h" namespace paddle { @@ -40,7 +41,8 @@ protected: vector layerMaps_; vector> parameters_; vector testLayers_; - LayerPtr dnnLayer_, refLayer_; + LayerPtr refLayer_; + MKLDNNLayerPtr dnnLayer_; /// run some iterations, all the result should pass size_t iter_; @@ -88,10 +90,10 @@ private: void checkBackwardData(); void checkBackwardWgts(); - void clearWgtDiffs(); - void clearBotDiffs(); - void clearBotDiffs(int n); // clear specific layer - void clearTopDatas(); + // clear specific layer, clear all when id equals NUM + void clearWgtDiffs(size_t id = NUM); + void clearBotDiffs(size_t id = NUM); + void clearTopDatas(size_t id = NUM); void printTopDatas(); void printMatrix(const MatrixPtr& m); diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 186a33edcec88bd5e51091a524a778eeb27ad526..4f380388b108dc173d847f027ba5c9db387a87f8 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -119,4 +119,4 @@ TEST(math, im2col) { #ifndef PADDLE_ONLY_CPU testIm2col(); #endif -} \ No newline at end of file +} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 227b75aff86089d0b21bdae7e6e402292bce67d9..3958b53c22c383e5e2298bfdc4e8490d4148118f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/backward.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" @@ -58,6 +59,8 @@ namespace paddle { namespace framework { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; static size_t UniqueIntegerGenerator() { static std::atomic generator; @@ -117,6 +120,60 @@ PYBIND11_PLUGIN(core) { return self.data()[offset]; }); + py::class_(m, "LoDTensor", R"DOC(LoD(Leval of Ddetails) Tensor. + +The tensor and LoD info should be created before creating the LoDTensor, then +call the set_tensor and set_lod functions to set them. + +)DOC") + .def("__init__", + [](LoDTensor &instance, + const std::vector> &lod, + Tensor *t) { +#ifdef PADDLE_ONLY_CPU + new (&instance) LoDTensor(lod, t); +#else + paddle::framework::LoD new_lod; + new_lod.reserve(lod.size()); + std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); + new (&instance) LoDTensor(new_lod, t); +#endif + }) + .def("set_tensor", + [](LoDTensor &self, Tensor *tensor) { self.set_tensor(tensor); }) + .def("set_lod", + [](LoDTensor &self, const std::vector> &lod) { +#ifdef PADDLE_ONLY_CPU + self.set_lod(lod); +#else + paddle::framework::LoD new_lod; + new_lod.reserve(lod.size()); + std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); + self.set_lod(new_lod); +#endif + }) + .def("tensor", + [](LoDTensor &self) -> Tensor & { return self.tensor(); }, + py::return_value_policy::reference) + .def("lod", [](LoDTensor &self) -> std::vector> { +#ifdef PADDLE_ONLY_CPU + return self.lod(); +#else + auto lod = self.lod(); + std::vector> new_lod; + new_lod.reserve(lod.size()); + std::transform(lod.begin(), lod.end(), std::back_inserter(new_lod), + [](paddle::framework::Vector item) -> + std::vector { + std::vector v; + v.reserve(item.size()); + std::copy(item.begin(), item.end(), std::back_inserter(v)); + return v; + }); + return new_lod; +#endif + }); + py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. @@ -128,6 +185,11 @@ All parameter, weight, gradient are variables in Paddle. .def("get_tensor", [](Variable &self) -> Tensor * { return self.GetMutable(); }, py::return_value_policy::reference) + .def("get_lod_tensor", + [](Variable &self) -> LoDTensor * { + return self.GetMutable(); + }, + py::return_value_policy::reference) .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py index 1af39818a305215b45219b8c5f0a10630fd64279..f26ed4964c521be1cd839b39d7244f96c653cb1a 100644 --- a/python/paddle/v2/framework/tests/test_tensor.py +++ b/python/paddle/v2/framework/tests/test_tensor.py @@ -3,7 +3,7 @@ import unittest import numpy -class TestScope(unittest.TestCase): +class TestTensor(unittest.TestCase): def test_int_tensor(self): scope = core.Scope() var = scope.new_var("test_tensor") @@ -20,8 +20,8 @@ class TestScope(unittest.TestCase): tensor.set(tensor_array, place) tensor_array_2 = numpy.array(tensor) - self.assertEqual(1.0, tensor_array_2[3, 9]) - self.assertEqual(2.0, tensor_array_2[19, 11]) + self.assertEqual(1, tensor_array_2[3, 9]) + self.assertEqual(2, tensor_array_2[19, 11]) def test_float_tensor(self): scope = core.Scope() @@ -43,6 +43,84 @@ class TestScope(unittest.TestCase): self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(2.0, tensor_array_2[19, 11]) + def test_int_lod_tensor(self): + places = [core.CPUPlace(), core.GPUPlace(0)] + for place in places: + scope = core.Scope() + var = scope.new_var("test_tensor") + var_lod = scope.new_var("test_lod_tensor") + + tensor = var.get_tensor() + lod_tensor = var_lod.get_lod_tensor() + + tensor.set_dims([4, 4, 6]) + tensor.alloc_int(place) + array = numpy.array(tensor) + array[0, 0, 0] = 3 + array[3, 3, 5] = 10 + tensor.set(array, place) + + lod_tensor.set_tensor(tensor) + lod_tensor.set_lod([[0, 2, 4]]) + + lod_v = numpy.array(lod_tensor.tensor()) + self.assertTrue(numpy.alltrue(array == lod_v)) + + lod = lod_tensor.lod() + self.assertEqual(0, lod[0][0]) + self.assertEqual(2, lod[0][1]) + self.assertEqual(4, lod[0][2]) + + def test_float_lod_tensor(self): + places = [core.CPUPlace(), core.GPUPlace(0)] + for place in places: + scope = core.Scope() + var = scope.new_var("test_tensor") + var_lod = scope.new_var("test_lod_tensor") + + tensor = var.get_tensor() + lod_tensor = var_lod.get_lod_tensor() + + tensor.set_dims([5, 2, 3, 4]) + tensor.alloc_float(place) + + tensor_array = numpy.array(tensor) + self.assertEqual((5, 2, 3, 4), tensor_array.shape) + tensor_array[0, 0, 0, 0] = 1.0 + tensor_array[0, 0, 0, 1] = 2.0 + tensor.set(tensor_array, place) + + lod_tensor.set_tensor(tensor) + + lod_v = numpy.array(lod_tensor.tensor()) + self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) + self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) + self.assertEqual(len(lod_tensor.lod()), 0) + + lod_py = [[0, 2, 5], [0, 2, 4, 5]] + lod_tensor.set_lod(lod_py) + lod = lod_tensor.lod() + self.assertListEqual(lod_py, lod) + + def test_lod_tensor_init(self): + scope = core.Scope() + var = scope.new_var("test_tensor") + place = core.CPUPlace() + tensor = var.get_tensor() + tensor.set_dims([5, 2, 3, 4]) + tensor.alloc_float(place) + tensor_array = numpy.array(tensor) + tensor_array[0, 0, 0, 0] = 1.0 + tensor_array[0, 0, 0, 1] = 2.0 + tensor.set(tensor_array, place) + lod_py = [[0, 2, 5], [0, 2, 4, 5]] + + lod_tensor = core.LoDTensor(lod_py, tensor) + lod_v = numpy.array(lod_tensor.tensor()) + self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) + self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) + self.assertListEqual(lod_py, lod_tensor.lod()) + if __name__ == '__main__': unittest.main()