diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst
index f43e83d1297637a84f8a8bd581d1ab94089efc28..8fb9369e0e8e31e620169fa2856094c414efe23e 100644
--- a/doc/getstarted/build_and_install/docker_install_en.rst
+++ b/doc/getstarted/build_and_install/docker_install_en.rst
@@ -8,199 +8,255 @@ Please be aware that you will need to change `Dockers settings
 `_ to make full use of your hardware resource on Mac OS X and Windows.
 
+Working With Docker
+-------------------
+
+Docker is simple to use once we understand a few basic concepts:
+
+- *image*: A Docker image is a package of software. It can contain one
+  or more programs and all of their dependencies. For example,
+  PaddlePaddle's Docker image includes the pre-built PaddlePaddle
+  binary, Python, and many Python packages, so we can run the image
+  directly rather than installing all this software ourselves. We can
+  type
+
+  .. code-block:: bash
+
+     docker images
+
+  to list all images in the system. We can also run
+
+  .. code-block:: bash
+
+     docker pull paddlepaddle/paddle:0.10.0rc2
+
+  to download a Docker image, paddlepaddle/paddle in this example,
+  from Dockerhub.com.
+
+- *container*: if we consider a Docker image a program, a container is
+  a "process" that runs the image. Indeed, a container is exactly an
+  operating system process, but with a virtualized filesystem, network
+  port space, and other virtualized resources. We can type
+
+  .. code-block:: bash
+
+     docker run paddlepaddle/paddle:0.10.0rc2
+
+  to start a container that runs the Docker image, paddlepaddle/paddle
+  in this example.
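+
+  Similarly, we can list all containers on the system, including
+  stopped ones, with the standard Docker command:
+
+  .. code-block:: bash
+
+     docker ps -a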
+
+- By default, a Docker container has an isolated file system
+  namespace, so we cannot see files in the host file system. By using
+  a *volume*, files in the host can be made visible inside the
+  container. The following command mounts the current directory into
+  /data inside the container, and runs the container from the debian
+  image with the command :code:`ls /data`.
+
+  .. code-block:: bash
+
+     docker run --rm -v $(pwd):/data debian ls /data
 
 Usage of CPU-only and GPU Images
 ----------------------------------
-For each version of PaddlePaddle, we release 2 types of Docker images: development
-image and production image. Production image includes CPU-only version and a CUDA
-GPU version and their no-AVX versions. We put the docker images on
-`dockerhub.com `_. You can find the
-latest versions under "tags" tab at dockerhub.com.
-1. development image :code:`paddlepaddle/paddle:<version>-dev`
+For each version of PaddlePaddle, we release two types of Docker images:
+a development image and a production image. The production image comes
+in a CPU-only version and a CUDA GPU version, each with a no-AVX
+variant. We put the docker images on `dockerhub.com
+<https://hub.docker.com/r/paddlepaddle/paddle/tags/>`_. You can find
+the latest versions under the "tags" tab at dockerhub.com.
 
-   This image has packed related develop tools and runtime environment. Users and
-   developers can use this image instead of their own local computer to accomplish
-   development, build, releasing, document writing etc. While different version of
-   paddle may depends on different version of libraries and tools, if you want to
-   setup a local environment, you must pay attention to the versions.
-   The development image contains:
-   - gcc/clang
-   - nvcc
-   - Python
-   - sphinx
-   - woboq
-   - sshd
-   Many developers use servers with GPUs, they can use ssh to login to the server
-   and run :code:`docker exec` to enter the docker container and start their work.
-   Also they can start a development docker image with SSHD service, so they can login to
-   the container and start work.
+1. Production images. Each production image comes in several variants:
 
-   To run the CPU-only image as an interactive container:
+   - GPU/AVX: :code:`paddlepaddle/paddle:<version>-gpu`
+   - GPU/no-AVX: :code:`paddlepaddle/paddle:<version>-gpu-noavx`
+   - CPU/AVX: :code:`paddlepaddle/paddle:<version>`
+   - CPU/no-AVX: :code:`paddlepaddle/paddle:<version>-noavx`
 
-   .. code-block:: bash
+   Please be aware that the CPU-only and the GPU images both use the
+   AVX instruction set, but old computers produced before 2008 do not
+   support AVX. The following command checks whether your Linux
+   computer supports AVX:
 
-      docker run -it --rm paddledev/paddle:<version> /bin/bash
+   .. code-block:: bash
 
-   or, we can run it as a daemon container
+      if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
 
-   .. code-block:: bash
+
+   To run the CPU-only image as an interactive container:
 
-      docker run -d -p 2202:22 -p 8888:8888 paddledev/paddle:<version>
+   .. code-block:: bash
 
-   and SSH to this container using password :code:`root`:
+      docker run -it --rm paddlepaddle/paddle:0.10.0rc2 /bin/bash
 
-   .. code-block:: bash
+   The above method works with the GPU image too -- the recommended
+   way is to use `nvidia-docker <https://github.com/NVIDIA/nvidia-docker>`_.
 
-      ssh -p 2202 root@localhost
+   Please install nvidia-docker first by following this `tutorial
+   <https://github.com/NVIDIA/nvidia-docker>`_.
 
-   An advantage of using SSH is that we can connect to PaddlePaddle from
-   more than one terminals. For example, one terminal running vi and
-   another one running Python interpreter. Another advantage is that we
-   can run the PaddlePaddle container on a remote server and SSH to it
-   from a laptop.
+   Now you can run a GPU image:
 
+   .. code-block:: bash
 
-2. Production images, this image might have multiple variants:
-   - GPU/AVX::code:`paddlepaddle/paddle:<version>-gpu`
-   - GPU/no-AVX::code:`paddlepaddle/paddle:<version>-gpu-noavx`
-   - CPU/AVX::code:`paddlepaddle/paddle:<version>`
-   - CPU/no-AVX::code:`paddlepaddle/paddle:<version>-noavx`
+      nvidia-docker run -it --rm paddlepaddle/paddle:0.10.0rc2-gpu /bin/bash
 
-   Please be aware that the CPU-only and the GPU images both use the AVX
-   instruction set, but old computers produced before 2008 do not support
-   AVX. The following command checks if your Linux computer supports
-   AVX:
+2. Development image :code:`paddlepaddle/paddle:<version>-dev`
 
-   .. code-block:: bash
+   This image packs the development tools and the runtime
+   environment. Users and developers can use this image instead of
+   their own local computer for development, building, releasing,
+   writing documentation, etc. Since different versions of
+   PaddlePaddle may depend on different versions of libraries and
+   tools, you must pay close attention to versions if you set up a
+   local environment instead. The development image contains:
+
+   - gcc/clang
+   - nvcc
+   - Python
+   - sphinx
+   - woboq
+   - sshd
+
+   Many developers work on servers with GPUs; they can use ssh to log
+   in to the server and run :code:`docker exec` to enter the docker
+   container and start their work. Alternatively, they can start a
+   development docker image with the SSHD service, so they can log in
+   to the container directly.
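+
+   For example, a typical remote workflow might look like the
+   following sketch, where the server address and the container name
+   are hypothetical:
+
+   .. code-block:: bash
+
+      ssh me@gpu-server.example.com          # log in to the GPU server
+      docker exec -it paddle-dev /bin/bash   # enter the running container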
-      if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
+Train Model Using Python API
+----------------------------
 
-   If it doesn't, we will use the non-AVX images.
+Our official docker image provides a runtime for PaddlePaddle
+programs. The typical workflow is as follows:
 
-   Above methods work with the GPU image too -- just please don't forget
-   to install GPU driver. To support GPU driver, we recommend to use
-   [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). Run using
+Create a directory as the workspace:
 
-   .. code-block:: bash
+.. code-block:: bash
 
-      nvidia-docker run -it --rm paddledev/paddle:0.10.0rc1-gpu /bin/bash
+   mkdir ~/workspace
 
-   Note: If you would have a problem running nvidia-docker, you may try the old method we have used (not recommended).
+Edit a PaddlePaddle python program using your favorite editor:
 
-   .. code-block:: bash
+.. code-block:: bash
 
-      export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
-      export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
-      docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:<version>-gpu
+   emacs ~/workspace/example.py
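+
+For illustration, a minimal example.py skeleton might look like the
+following sketch; everything beyond :code:`paddle.init` -- the data,
+network, optimizer, and trainer -- is up to you:
+
+.. code-block:: python
+
+   # example.py -- a minimal PaddlePaddle v2 program sketch
+   import paddle.v2 as paddle
+
+   paddle.init(use_gpu=False, trainer_count=1)
+   # ... define your data, network, optimizer, and trainer here ...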
+
+Run the program using docker:
 
-3. Use production image to release you AI application
-   Suppose that we have a simple application program in :code:`a.py`, we can test and run it using the production image:
+.. code-block:: bash
 
-   ```bash
-   docker run -it -v $PWD:/work paddle /work/a.py
-   ```
+   docker run --rm -v ~/workspace:/workspace paddlepaddle/paddle:0.10.0rc2 python /workspace/example.py
 
-   But this works only if all dependencies of :code:`a.py` are in the production image. If this is not the case, we need to build a new Docker image from the production image and with more dependencies installs.
+Or, if you are using the GPU for training:
 
+.. code-block:: bash
 
-PaddlePaddle Book
-------------------
+   nvidia-docker run --rm -v ~/workspace:/workspace paddlepaddle/paddle:0.10.0rc2-gpu python /workspace/example.py
 
-The Jupyter Notebook is an open-source web application that allows
-you to create and share documents that contain live code, equations,
-visualizations and explanatory text in a single browser.
+The above commands start a docker container by running :code:`python
+/workspace/example.py`. The container stops once :code:`python
+/workspace/example.py` finishes.
 
-PaddlePaddle Book is an interactive Jupyter Notebook for users and developers.
-We already exposed port 8888 for this book. If you want to
-dig deeper into deep learning, PaddlePaddle Book definitely is your best choice.
+Another way is to tell docker to start a :code:`/bin/bash` session and
+run the PaddlePaddle program interactively:
 
-We provide a packaged book image, simply issue the command:
+.. code-block:: bash
+
+   docker run -it -v ~/workspace:/workspace paddlepaddle/paddle:0.10.0rc2 /bin/bash
+   # now we are inside the docker container
+   cd /workspace
+   python example.py
 
+Running with the GPU is identical:
 
 .. code-block:: bash
 
-   docker run -p 8888:8888 paddlepaddle/book
+   nvidia-docker run -it -v ~/workspace:/workspace paddlepaddle/paddle:0.10.0rc2-gpu /bin/bash
+   # now we are inside the docker container
+   cd /workspace
+   python example.py
 
-Then, you would back and paste the address into the local browser:
 
-.. code-block:: text
+Develop PaddlePaddle or Train Model Using C++ API
+---------------------------------------------------
 
-   http://localhost:8888/
+We will use the PaddlePaddle development image, since it contains all
+the compilation tools and dependencies.
 
-That's all. Enjoy your journey!
+Let's clone the PaddlePaddle repo first:
 
-Development Using Docker
-------------------------
+.. code-block:: bash
 
-Developers can work on PaddlePaddle using Docker. This allows
-developers to work on different platforms -- Linux, Mac OS X, and
-Windows -- in a consistent way.
+   git clone https://github.com/PaddlePaddle/Paddle.git && cd Paddle
 
-1. Build the Development Docker Image
+Mount both the workspace folder and the paddle code folder into the
+docker container, so we can access them inside the container. There
+are two ways of using the PaddlePaddle development docker image:
 
-   .. code-block:: bash
+- run an interactive bash directly
 
-      git clone --recursive https://github.com/PaddlePaddle/Paddle
-      cd Paddle
-      docker build -t paddle:dev .
+  .. code-block:: bash
 
-   Note that by default :code:`docker build` wouldn't import source
-   tree into the image and build it. If we want to do that, we need docker the
-   development docker image and then run the following command:
+     # use nvidia-docker instead of docker if you need to use GPU
+     docker run -it -v ~/workspace:/workspace -v $(pwd):/paddle paddlepaddle/paddle:0.10.0rc2-dev /bin/bash
+     # now we are inside the docker container
 
-   .. code-block:: bash
+- or, we can run it as a daemon container
 
-      docker run -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "TEST=OFF" paddle:dev
+  .. code-block:: bash
 
+     # use nvidia-docker instead of docker if you need to use GPU
+     docker run -d -p 2202:22 -p 8888:8888 -v ~/workspace:/workspace -v $(pwd):/paddle paddlepaddle/paddle:0.10.0rc2-dev /usr/sbin/sshd -D
 
-2. Run the Development Environment
+  and SSH to this container using the password :code:`root`:
 
-   Once we got the image :code:`paddle:dev`, we can use it to develop
-   Paddle by mounting the local source code tree into a container that
-   runs the image:
+  .. code-block:: bash
 
-   .. code-block:: bash
+     ssh -p 2202 root@localhost
 
-      docker run -d -p 2202:22 -p 8888:8888 -v $PWD:/paddle paddle:dev sshd
+  An advantage is that we can run the PaddlePaddle container on a
+  remote server and SSH to it from a laptop.
 
-   This runs a container of the development environment Docker image
-   with the local source tree mounted to :code:`/paddle` of the
-   container.
+When developing PaddlePaddle, you can edit the PaddlePaddle source
+code from outside the docker container using your favorite editor. To
+compile PaddlePaddle, run inside the container:
 
-   The above :code:`docker run` commands actually starts
-   an SSHD server listening on port 2202. This allows us to log into
-   this container with:
+.. code-block:: bash
 
-   .. code-block:: bash
+   WITH_GPU=OFF WITH_AVX=ON WITH_TEST=ON bash /paddle/paddle/scripts/docker/build.sh
 
-      ssh root@localhost -p 2202
+This builds everything in :code:`/paddle/build`, and we can run the
+unit tests there:
 
-   Usually, I run above commands on my Mac. I can also run them on a
-   GPU server :code:`xxx.yyy.zzz.www` and ssh from my Mac to it:
+.. code-block:: bash
 
-   .. code-block:: bash
+   cd /paddle/build
+   ctest
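+
+To run a single unit test, for example the matrix tests touched by
+this change, we can use ctest's standard name filter:
+
+.. code-block:: bash
+
+   ctest -R test_matrixCompare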
-      my-mac$ ssh root@xxx.yyy.zzz.www -p 2202
+When training a model using the C++ API, we can edit the paddle
+program in ~/workspace outside of docker, and build it from
+/workspace inside of docker.
 
-3. Build and Install Using the Development Environment
+PaddlePaddle Book
+------------------
 
-   Once I am in the container, I can use
-   :code:`paddle/scripts/docker/build.sh` to build, install, and test
-   Paddle:
+The Jupyter Notebook is an open-source web application that allows
+you to create and share documents that contain live code, equations,
+visualizations and explanatory text in a single browser.
 
-   .. code-block:: bash
+PaddlePaddle Book is an interactive Jupyter Notebook for users and developers.
+The image already exposes port 8888 for this book. If you want to
+dig deeper into deep learning, PaddlePaddle Book is definitely your best choice.
 
-      /paddle/paddle/scripts/docker/build.sh
+We provide a packaged book image; simply issue the command:
 
-   This builds everything about Paddle in :code:`/paddle/build`. And
-   we can run unit tests there:
+.. code-block:: bash
 
-   .. code-block:: bash
+   docker run -p 8888:8888 paddlepaddle/book
+
+Then, copy and paste the printed address into your local browser:
+
+.. code-block:: text
+
+   http://localhost:8888/
 
-      cd /paddle/build
-      ctest
+That's all. Enjoy your journey!
 
 
 Documentation
diff --git a/paddle/cuda/include/hl_sequence.h b/paddle/cuda/include/hl_sequence.h
index 9f9d8f972e3a4c62e5caedcf85054be5681b96c1..973ddcceed99ba4177b3db277e664611d42ac51b 100644
--- a/paddle/cuda/include/hl_sequence.h
+++ b/paddle/cuda/include/hl_sequence.h
@@ -159,4 +159,10 @@ extern void hl_sequence_avg_forward(real* dst,
                                     int width,
                                     const int mode);
 
+extern void hl_sequence_avg_backward(real* dst,
+                                     real* src,
+                                     const int* starts,
+                                     int height,
+                                     int width,
+                                     const int mode);
 #endif /* HL_SEQUENCE_H_ */
diff --git a/paddle/cuda/include/stub/hl_sequence_stub.h b/paddle/cuda/include/stub/hl_sequence_stub.h
index 05e51bce9e1df6fc6ef1cad891b44a9172da185d..920b417b1c717efaff75f70f1b9d2b574469e425 100644
--- a/paddle/cuda/include/stub/hl_sequence_stub.h
+++ b/paddle/cuda/include/stub/hl_sequence_stub.h
@@ -57,4 +57,10 @@ inline void hl_sequence_avg_forward(real* dst,
                                     int width,
                                     const int mode) {}
 
+inline void hl_sequence_avg_backward(real* dst,
+                                     real* src,
+                                     const int* starts,
+                                     int height,
+                                     int width,
+                                     const int mode) {}
 #endif  // HL_SEQUENCE_STUB_H_
diff --git a/paddle/cuda/src/hl_cuda_sequence.cu b/paddle/cuda/src/hl_cuda_sequence.cu
index ba823de2720336851bf9c49d8162360af93e8601..0fe2877f89f8d0fbc4db40c400037be30bb87ff7 100644
--- a/paddle/cuda/src/hl_cuda_sequence.cu
+++ b/paddle/cuda/src/hl_cuda_sequence.cu
@@ -325,12 +325,12 @@ __global__ void KeSequenceAvgForward(real* dst,
     int seqLength = end - start;
     if (seqLength == 0) return;
     real sum = 0.0;
-    for (int i = 0; i < seqLength; i++) {
-      sum += src[(start + i) * width + col];
+    for (int i = start; i < end; i++) {
+      sum += src[i * width + col];
     }
     sum = mode == 1 ? sum :
         (mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
-    dst[row * width + col] = sum;
+    dst[gid] = sum;
   }
 }
 
@@ -354,3 +354,48 @@ void hl_sequence_avg_forward(real* dst,
       (dst, src, starts, height, width, mode);
   CHECK_SYNC("hl_sequence_avg_forward failed");
 }
+
+__global__ void KeSequenceAvgBackward(real* dst,
+                                      real* src,
+                                      const int* starts,
+                                      int height,
+                                      int width,
+                                      const int mode) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int row = gid / width;
+  int col = gid % width;
+
+  if (gid < height * width) {
+    int start = starts[row];
+    int end = starts[row + 1];
+    int seqLength = end - start;
+    if (seqLength == 0) return;
+    real grad = src[gid];
+    grad = mode == 1 ? grad :
+        (mode == 0 ? grad / seqLength : grad * my_rsqrt((real)seqLength));
+    for (int i = start; i < end; i++) {
+      dst[i * width + col] += grad;
+    }
+  }
+}
+
+void hl_sequence_avg_backward(real* dst,
+                              real* src,
+                              const int* starts,
+                              int height,
+                              int width,
+                              const int mode) {
+  CHECK_NOTNULL(dst);
+  CHECK_NOTNULL(src);
+  CHECK_NOTNULL(starts);
+
+  int block = 512;
+  int grid = DIVUP(width * height, 512);
+
+  CHECK(mode == 0 || mode == 1 || mode == 2)
+      << "mode error in hl_sequence_avg_backward!";
+
+  KeSequenceAvgBackward<<< grid, block, 0, STREAM_DEFAULT >>>
+      (dst, src, starts, height, width, mode);
+  CHECK_SYNC("hl_sequence_avg_backward failed");
+}
diff --git a/paddle/gserver/layers/AverageLayer.cpp b/paddle/gserver/layers/AverageLayer.cpp
index b8955ab04f209629c855ed66f8e8e9701b7224a3..96cc4288c6faad4b80c790ed2ce6f5128ea83b6d 100644
--- a/paddle/gserver/layers/AverageLayer.cpp
+++ b/paddle/gserver/layers/AverageLayer.cpp
@@ -26,8 +26,6 @@ bool AverageLayer::init(const LayerMap& layerMap,
                         const ParameterMap& parameterMap) {
   SequencePoolLayer::init(layerMap, parameterMap);
 
-  dataMtx_ = Matrix::create(nullptr, 1, 1, false, useGpu_);
-  outMtx_ = Matrix::create(nullptr, 1, getSize(), false, useGpu_);
   // average strategy
   if (config_.average_strategy() == "average") {
     mode_ = kAverage;
@@ -60,43 +58,9 @@ void AverageLayer::forward(PassType passType) {
 void AverageLayer::backward(const UpdateCallback& callback) {
   SequencePoolLayer::backward(callback);
 
-  const int* starts = startPositions_->getData(false);
-  MatrixPtr grad = getInputGrad(0);
-
-  if (grad) {
-    size_t dim = getSize();
-    real* gradientData = getInputGrad(0)->getData();
-    real* gradient = getOutputGrad()->getData();
-    size_t numSequences = startPositions_->getSize() - 1;
-    for (size_t sequenceId = 0; sequenceId < numSequences; ++sequenceId) {
-      // TODO(Dangqingqing) optimization for GPU
-      int sequenceLength = starts[sequenceId + 1] - starts[sequenceId];
-      if (0 == sequenceLength) {
-        // empty sequence
-        continue;
-      }
-      dataMtx_->setData(
-          gradientData + starts[sequenceId] * dim, sequenceLength, dim);
-      outMtx_->setData(gradient + sequenceId * dim);
-      switch (mode_) {
-        case kAverage: {
-          // plain average
-          dataMtx_->addBias(*outMtx_, 1.0f / sequenceLength);
-          break;
-        }
-        case kSum: {
-          // sum instead of average
-          dataMtx_->addBias(*outMtx_, 1.0f);
-          break;
-        }
-        case kAverageSquareRootN: {
-          // divide by square root of sequenceLength
-          dataMtx_->addBias(*outMtx_, 1.0f / sqrt(sequenceLength));
-          break;
-        }
-        default: { LOG(FATAL) << "should not reach here"; }
-      }
-    }
+  if (getInputGrad(0)) {
+    getInputGrad(0)->sequenceAvgBackward(
+        *getOutputGrad(), *startPositions_->getVector(useGpu_), mode_);
   }
 }
 
diff --git a/paddle/gserver/layers/AverageLayer.h b/paddle/gserver/layers/AverageLayer.h
index 621e1d7bb12ec5b8c7a6173bd601835d9406e814..332552a30479a368c24db10e5ef3a9d59408c8ef 100644
--- a/paddle/gserver/layers/AverageLayer.h
+++ b/paddle/gserver/layers/AverageLayer.h
@@ -45,8 +45,6 @@ public:
   void backward(const UpdateCallback& callback = nullptr) override;
 
 protected:
-  MatrixPtr outMtx_;
-  MatrixPtr dataMtx_;
   int mode_;
 };
 }  // namespace paddle
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 9eead5b62c690b0a3310d8b68bfa3f1870be17c2..5f30a15f2eb913d57d01479cf132e188b9e7c813 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -483,6 +483,20 @@ void GpuMatrix::sequenceAvgForward(Matrix& a,
   hl_sequence_avg_forward(dst, src, starts, height, width, mode);
 }
 
+void GpuMatrix::sequenceAvgBackward(Matrix& a,
+                                    const IVector& startsPos,
+                                    int mode) {
+  size_t height = a.getHeight();
+  size_t width = getWidth();
+  CHECK_EQ(height, startsPos.getSize() - 1);
+  CHECK_EQ(width, a.getWidth());
+  real* dst = getData();
+  real* src = a.getData();
+  const int* starts = startsPos.getData();
+
+  hl_sequence_avg_backward(dst, src, starts, height, width, mode);
+}
+
 /* this = scaleAB*(a*b) + scaleT*this */
 void GpuMatrix::mul(const GpuMatrix& a,
                     const GpuMatrix& b,
@@ -2304,6 +2318,41 @@ void CpuMatrix::sequenceAvgForward(Matrix& a,
   }
 }
 
+void CpuMatrix::sequenceAvgBackward(Matrix& a,
+                                    const IVector& startsPos,
+                                    int mode) {
+  size_t height = a.getHeight();
+  size_t width = getWidth();
+  CHECK_EQ(height, startsPos.getSize() - 1);
+  CHECK_EQ(width, a.getWidth());
+  real* dst = getData();
+  real* src = a.getData();
+  const int* starts = startsPos.getData();
+  MatrixPtr outMtx = Matrix::create(nullptr, 1, width, false, false);
+  MatrixPtr dataMtx = Matrix::create(nullptr, 1, width, false, false);
+  for (size_t i = 0; i < height; ++i) {
+    int sequenceLength = starts[i + 1] - starts[i];
+    if (0 == sequenceLength) {
+      // empty sequence
+      continue;
+    }
+    outMtx->setData(dst + starts[i] * width, sequenceLength, width);
+    dataMtx->setData(src + i * width);
+    if (mode == 0) {
+      // plain average
+      outMtx->addBias(*dataMtx, 1.0f / sequenceLength);
+    } else if (mode == 1) {
+      // sum instead of average
+      outMtx->addBias(*dataMtx, 1.0f);
+    } else if (mode == 2) {
+      // divide by square root of sequenceLength
+      outMtx->addBias(*dataMtx, 1.0f / std::sqrt(sequenceLength));
+    } else {
+      LOG(FATAL) << "should not reach here";
+    }
+  }
+}
+
 /* this = scaleAB*(a*b) + scaleT*this*/
 void CpuMatrix::mul(const Matrix& a,
                     const Matrix& b,
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index dbdb629614546b7c7b569d7473d96a06d0c5a9c7..3252adb19e4c2e48f86c3c811bfc7d75fd06a8f7 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -461,6 +461,12 @@ public:
     LOG(FATAL) << "Not implemented";
   }
 
+  virtual void sequenceAvgBackward(Matrix& a,
+                                   const IVector& startsPos,
+                                   int mode) {
+    LOG(FATAL) << "Not implemented";
+  }
+
   /**
    * @code
    * this = scaleAB*(a*b) + scaleT*this
@@ -1203,6 +1209,7 @@ public:
   void collectSharedBias(Matrix& a, real scale);
 
   void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
+  void sequenceAvgBackward(Matrix& a, const IVector& startsPos, int mode);
 
   /**
    * @code
@@ -1619,6 +1626,7 @@ public:
   void collectSharedBias(Matrix& a, real scale);
 
   void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
+  void sequenceAvgBackward(Matrix& a, const IVector& startsPos, int mode);
 
   /**
    * @code
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index 08b64c1bb6f5d359a2d2164e723a76c5360168ee..dd19fe516fbf724a86479e6f27032614ab4c6106 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -685,7 +685,7 @@ TEST(SMatrix, topK) {
   }
 }
 
-void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) {
+void testMatrixSequenceAvg(int batchSize, int inputDim, int mode) {
   MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
   MatrixPtr gpuInput = std::make_shared<GpuMatrix>(batchSize, inputDim);
   cpuInput->randomizeUniform();
@@ -706,15 +706,25 @@ void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) {
   gpuOutput->sequenceAvgForward(*gpuInput, *gpuSequence, mode);
 
   TensorCheckErr(*cpuOutput, *gpuOutput);
+
+  MatrixPtr cpuInGrad = std::make_shared<CpuMatrix>(batchSize, inputDim);
+  MatrixPtr gpuInGrad = std::make_shared<GpuMatrix>(batchSize, inputDim);
+  cpuInGrad->randomizeUniform();
+  gpuInGrad->copyFrom(*cpuInGrad);
+
+  cpuInGrad->sequenceAvgBackward(*cpuOutput, *cpuSequence, mode);
+  gpuInGrad->sequenceAvgBackward(*gpuOutput, *gpuSequence, mode);
+
+  TensorCheckErr(*cpuInGrad, *gpuInGrad);
 }
 
-TEST(Matrix, sequenceAvgForward) {
+TEST(Matrix, sequenceAvg) {
   for (auto batchSize : {10, 128, 6000}) {
     for (auto inputDim : {32, 100, 512}) {
       for (auto mode : {0, 1, 2}) {
         VLOG(3) << " batchSize=" << batchSize << " inputDim=" << inputDim
                 << " mode=" << mode;
-        testMatrixSequenceAvgForward(batchSize, inputDim, mode);
+        testMatrixSequenceAvg(batchSize, inputDim, mode);
       }
     }
   }
 }
diff --git a/python/paddle/v2/plot/plot_curve.py b/python/paddle/v2/plot/plot_curve.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f62674cb2baad9e4ecd9f6655f7e2dc00173dc6
--- /dev/null
+++ b/python/paddle/v2/plot/plot_curve.py
@@ -0,0 +1,48 @@
+from IPython import display
+import os
+
+
+class PlotCost(object):
+    """
+    Append train/test costs from an event handler, then call plot().
+    """
+
+    def __init__(self):
+        self.train_costs = ([], [])
+        self.test_costs = ([], [])
+
+        self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
+        if not self.__plot_is_disabled__():
+            import matplotlib.pyplot as plt
+            self.plt = plt
+
+    def __plot_is_disabled__(self):
+        return self.__disable_plot__ == "True"
+
+    def plot(self):
+        if self.__plot_is_disabled__():
+            return
+
+        self.plt.plot(*self.train_costs)
+        self.plt.plot(*self.test_costs)
+        title = []
+        if len(self.train_costs[0]) > 0:
+            title.append('Train Cost')
+        if len(self.test_costs[0]) > 0:
+            title.append('Test Cost')
+        self.plt.legend(title, loc='upper left')
+        display.clear_output(wait=True)
+        display.display(self.plt.gcf())
+        self.plt.gcf().clear()
+
+    def append_train_cost(self, step, cost):
+        self.train_costs[0].append(step)
+        self.train_costs[1].append(cost)
+
+    def append_test_cost(self, step, cost):
+        self.test_costs[0].append(step)
+        self.test_costs[1].append(cost)
+
+    def reset(self):
+        self.train_costs = ([], [])
+        self.test_costs = ([], [])
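+
+
+if __name__ == "__main__":
+    # Illustrative smoke test only, not part of the PaddlePaddle API:
+    # feed PlotCost some made-up costs and draw them. In real training
+    # code, call append_train_cost()/append_test_cost() from your event
+    # handler and then call plot().
+    demo = PlotCost()
+    for step in range(10):
+        demo.append_train_cost(step, 1.0 / (step + 1))
+        if step % 2 == 0:
+            demo.append_test_cost(step, 1.2 / (step + 1))
+    demo.plot()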