提交 daa50117 编写于 作者: Y Yu Yang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/polish_reshape_op

......@@ -139,9 +139,6 @@ def run_benchmark(model, args):
# inference program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
target_vars=[batch_acc, batch_size_tensor])
# Optimization
opt = fluid.optimizer.AdamOptimizer(
......@@ -161,7 +158,7 @@ def run_benchmark(model, args):
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
accuracy = fluid.average.WeightedAverage()
accuracy = fluid.metrics.Accuracy()
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
accuracy.reset()
......@@ -184,7 +181,7 @@ def run_benchmark(model, args):
"label": y_data},
fetch_list=[avg_cost, batch_acc, batch_size_tensor]
) # The accuracy is the accumulation of batches, but not the current batch.
accuracy.add(value=outs[1], weight=outs[2])
accuracy.update(value=outs[1], weight=outs[2])
iters += 1
num_samples += len(y_data)
loss = np.array(outs[0])
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(NOT WITH_GPU)
return()
endif()
include(ExternalProject)
set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies
set(NCCL_BUILD_COMMAND "")
set(NCCL_INSTALL_COMMAND "")
set(NCCL_INSTALL_DIR "")
else()
# otherwise, we build nccl and link it.
set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
# Note: cuda 8.0 is needed to make nccl
# When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
endif()
ExternalProject_Add(
extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND ""
)
if(WITH_DSO)
if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile})
else()
add_library(nccl INTERFACE)
endif()
else()
add_library(nccl STATIC IMPORTED GLOBAL)
set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif()
add_dependencies(nccl extern_nccl)
......@@ -65,39 +65,55 @@ PaddlePaddle.org工具可以配合Docker使用,需要在系统里先安装好D
不使用PaddlePaddle.org工具
--------------------------
使用Docker构建PaddlePaddle的文档,需要在系统里先安装好Docker工具包。Docker安装请参考 `Docker的官网 <https://docs.docker.com/>`_ 。安装好Docker之后可以使用源码目录下的脚本构建文档,即
使用Docker构建PaddlePaddle的文档,需要在系统里先安装好Docker工具包。Docker安装请参考 `Docker的官网 <https://docs.docker.com/>`_ 。该方法与 `从源码编译PaddlePaddle <http://paddlepaddle.org/docs/develop/documentation/zh/build_and_install/build_from_source_cn.html>`_ 相似,通过从源码中构建可用于编译PaddlePaddle文档的Docker镜像并运行,在进入Docker容器后使用源码中的脚本构建PaddlePaddle文档,具体步骤如下:
[TBD]
.. code-block:: bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
# 从源码中构建可用于编译PaddlePaddle文档的Docker镜像
docker build -t paddle:dev .
docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" -e "WITH_DOC=ON" paddle:dev /bin/bash
# 进入Docker容器后使用build.sh脚本构建PaddlePaddle文档
bash -x /paddle/paddle/scripts/docker/build.sh
注:上述命令把当前目录(源码根目录)映射为 container 里的 :code:`/paddle` 目录。
编译完成后,会产生 ``doc/v2`` 和 ``doc/fluid`` 两个目录,在这两个目录下分别都生成 ``cn/html/`` 、 ``en/html`` 、 ``api/en/html`` 共三个子目录,分别进入这些目录下,执行以下命令:
.. code-block:: bash
python -m SimpleHTTPServer 8088
在浏览器中输入 http://localhost:8088 就可以看到编译生成的 ``v2`` 和 ``fluid`` 两种版本的中/英文的文档页面和英文的API页面。
如果不想使用Docker,也可以使用以下命令直接构建PaddlePaddle文档,即
.. code-block:: bash
mkdir paddle
cd paddle
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON
# 如果只需要构建使用文档,则执行以下命令
make -j $processors gen_proto_py
make -j $processors paddle_docs paddle_docs_cn
make -j $processors paddle_docs
# 如果只需要构建API,则执行以下命令
make -j $processors gen_proto_py framework_py_proto
make -j $processors copy_paddle_pybind
make -j $processors paddle_api_docs
make -j $processors paddle_apis
其中$processors代表启动和CPU核一样多的进程来并行编译,可以根据本机的CPU核数设置相应的值。
编译完成后,进入 ``doc/v2`` 目录,如果选择构建文档则会在该目录下生成 ``cn/html/`` 、 ``en/html`` 两个子目录,选择构建API则会生成 ``api/en/html`` 目录,分别进入这些目录下,执行以下命令:
编译完成后,同样会产生 ``doc/v2`` 和 ``doc/fluid`` 两个目录,如果选择构建文档则会在这两个目录下分别都生成 ``cn/html/`` 、 ``en/html`` 两个子目录,选择构建API则会在这两个目录下分别生成 ``api/en/html`` 目录,分别进入这些子目录下,执行以下命令:
.. code-block:: bash
python -m SimpleHTTPServer 8088
在浏览器中输入 http://localhost:8088 就可以看到编译生成的中/英文的文档页面和英文的API页面,下图为生成的英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。
在浏览器中输入 http://localhost:8088 就可以看到编译生成的 ``v2`` 和 ``fluid`` 两种版本的中/英文的文档页面和英文的API页面。下图为生成的 ``v2`` 英文文档首页示例。注意,示例中由于使用了sphinx的原始主题,所以页面的风格与官网并不一致,但这并不影响开发者进行调试。
.. image:: src/doc_en.png
:align: center
......
......@@ -68,39 +68,56 @@ Please `click here <https://github.com/PaddlePaddle/PaddlePaddle.org/blob/develo
Manually Building the Documentation
-------------------------------------
Build PaddlePaddle's documentation with Docker,you need to install Docker first. Please refer to `Docker's official website <https://docs.docker.com/>`_ on how to install Docker. After Docker is installed, you could use the scripts in the source directory to build the documentation.
Build PaddlePaddle's documentation with Docker,you need to install Docker first. Please refer to `Docker's official website <https://docs.docker.com/>`_ on how to install Docker. This method is quite similar to ` Build From Sources <http://paddlepaddle.org/docs/develop/documentation/en/build_and_install/build_from_source_en.html>`_ , by constructing, from source code, a docker image that can be used to build PaddlePaddle documentation. Enter the Docker container and use the script ``build.sh`` in the source directory to build the PaddlePaddle documentation. The specific steps are as follows:
[TBD]
.. code-block:: bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
# Construct a docker image from source code
docker build -t paddle:dev .
docker run -it -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_TESTING=OFF" -e "WITH_DOC=ON" paddle:dev /bin/bash
# Use build.sh to build PaddlePaddle documentation
bash -x /paddle/paddle/scripts/docker/build.sh
Note: The above commands maps the current directory (source root directory) to the :code:`/paddle` directory in the container.
After compiling, there should be two generated directories: ``doc/v2`` and ``doc/fluid``, where three subdirectories ``cn/html/``, ``en/html`` and ``api/en/html`` are generated. Please enter these directories respectively and execute the following commands:
.. code-block:: bash
python -m SimpleHTTPServer 8088
Use a web browser and navigate to http://localhost:8000, you could see the compiled ``v2`` 's and ``fluid`` 's Chinese/English documents page and English APIs page.
If you do not wish to use Docker, you can also use the following commands to directly build the PaddlePaddle documentation.
.. code-block:: bash
mkdir paddle
cd paddle
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON
# If you only need to build documents, use the following commands
make -j $processors gen_proto_py
make -j $processors paddle_docs paddle_docs_cn
make -j $processors paddle_docs
# If you only need to build APIs, use the following commands
make -j $processors gen_proto_py framework_py_proto
make -j $processors copy_paddle_pybind
make -j $processors paddle_api_docs
make -j $processors paddle_apis
$processors indicates that as many processes as the CPU cores are started to compile in parallel. It should be set according to the number of CPU cores of your machine.
After the compilation is complete, enter the ``doc/v2`` directory. If you chose to build documents, it will generate ``cn/html/`` and ``en/html`` subdirectories under this directory. If you chose to build APIs,it will generate``api/en/html`` subdirectory. Please enter these directories respectively and execute the following commands:
After compiling, there also should be two generated directories: ``doc/v2`` and ``doc/fluid`` . If you chose to build documents, two subdirectories ``cn/html/`` and ``en/html`` will be generated in both two directories. If you chose to build APIs,a subdirectory ``api/en/html`` will be generated. Please enter these directories respectively and execute the following commands:
.. code-block:: bash
python -m SimpleHTTPServer 8088
Use a web browser and navigate to http://localhost:8000, you could see the compiled Chinese/English documents page and the English APIs page. The following figure is an example of the built English documents home page. Note that due to the sphinx's original theme used in the example, the style of the page is not consistent with the official website, but this does not affect the developer's debugging.
Use a web browser and navigate to http://localhost:8000, you could see the compiled ``v2`` 's and ``fluid`` 's Chinese/English documents page and English APIs page. The following figure is an example of the built ``v2`` 's English documents home page. Note that due to the sphinx's original theme used in the example, the style of the page is not consistent with the official website, but this does not affect the developer's debugging.
.. image:: src/doc_en.png
:align: center
......
......@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......@@ -15,7 +16,7 @@ else()
set(multi_devices_graph_builder_deps)
endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
......@@ -54,6 +55,27 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
......@@ -76,27 +98,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// append send op if program is distributed trainer main program.
// always use the first device
if (!is_forwarding && op->Type() == "send") {
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0);
continue;
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
CreateOpHandleIOs(&result, op, p, i);
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
auto var_names = op->OutputArgumentNames();
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
......
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
......@@ -41,6 +44,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
op_->Run(*local_scope_, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace paddle {
namespace framework {
......@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
void WriteToRecordIO(recordio::Writer &writer,
void WriteToRecordIO(recordio::Writer *writer,
const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) {
std::stringstream buffer;
......@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx);
}
writer.Write(buffer.str());
writer->Write(buffer.str());
}
std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
std::vector<LoDTensor> result;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
if (scanner->HasNext()) {
std::istringstream sin(scanner->Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
}
return result;
}
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
......@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer,
extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace framework {
......@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
WriteToRecordIO(&writer, {tensor, tensor}, ctx);
writer.Flush();
}
......@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx);
auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx);
tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
......
......@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var);
if (!main_var->IsType<LoDTensor>()) {
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
......
......@@ -48,13 +48,13 @@ class ParallelExecutor {
const std::string& fetched_var_name,
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
private:
void SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
ParallelExecutorPrivate* member_;
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
};
} // namespace framework
......
......@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
if (out->empty()) {
return;
}
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];
......
......@@ -14,14 +14,13 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle {
namespace framework {
......@@ -31,8 +30,6 @@ class ReaderBase {
virtual void ReInit() = 0;
virtual bool HasNext() const = 0;
virtual ~ReaderBase();
};
......@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected:
ReaderBase* reader_;
};
......@@ -80,8 +75,6 @@ class ReaderHolder {
reader_->ReInit();
}
bool HasNext() const { return reader_->HasNext(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
......
......@@ -24,7 +24,8 @@ function(inference_test TARGET_NAME)
endforeach()
endfunction(inference_test)
inference_test(fit_a_line)
# This unittest is buggy!
#inference_test(fit_a_line)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp conv)
......
......@@ -65,9 +65,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
}
void ProcGetResponse(const VarHandle& var_h,
// const sendrecv::VariableMessage& ret_msg) {
const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = NULL;
framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
}
......
......@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (int64_t i = 0; i < rows2->size(); ++i) {
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
......
......@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
PADDLE_ENFORCE(!ins.empty(), "There is no next data.");
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) {
auto* out =
......
......@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
......@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher();
}
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); }
private:
bool HasNext() const;
void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
......@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
......@@ -140,28 +143,22 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
enum_range.insert("AUTO");
AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range});
}
};
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
out->clear();
if (HasNext()) {
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
}
}
}
......@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher();
}
bool DoubleBufferReader::HasNext() const {
while (!channel_->IsClosed() && !channel_->CanReceive()) {
}
return channel_->CanReceive();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;
while (reader_->HasNext()) {
while (true) {
Item batch;
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
......
......@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
}
bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
if (out->empty()) {
++pass_count_;
if (pass_count_ >= pass_num_) {
return false;
} else {
if (pass_count_ < pass_num_) {
reader_->ReInit();
return true;
reader_->ReadNext(out);
}
}
}
......
......@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void ReInit() override { return; }
bool HasNext() const override { return true; }
private:
float min_;
float max_;
......@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
......
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
......@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename;
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
} else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
}
}
......@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename");
......
......@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device;
seed_ = device();
}
ReadIntoBuffers();
ReloadBuffer();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
out->clear();
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
ReloadBuffer();
if (buffer_.empty()) {
return;
}
}
*out = buffer_[iteration_pos_++];
}
bool HasNext() const override {
return iteration_pos_ < buffer_.size() || reader_->HasNext();
}
private:
void ReadIntoBuffers() {
void ReloadBuffer() {
buffer_.clear();
buffer_.reserve(buffer_size_);
iteration_pos_ = 0;
for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader_->ReadNext(&ins);
if (ins.empty()) {
break;
}
buffer_.emplace_back();
reader_->ReadNext(&buffer_.back());
buffer_.emplace_back(ins);
}
std::mt19937 g(seed_);
std::shuffle(buffer_.begin(), buffer_.end(), g);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
ThreadedReader(ReaderBase* reader, bool safe_mode)
: DecoratedReader(reader), safe_mode_(safe_mode) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override {
if (safe_mode_) {
PADDLE_THROW(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true.");
}
VLOG(5) << "ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment.";
reader_->ReInit();
}
private:
bool safe_mode_;
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
bool safe_mode = Attr<bool>("safe_mode");
out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<bool>("safe_mode",
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment.")
.SetDefault(true);
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
......@@ -19,38 +21,23 @@ namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
class MultiFileReader : public framework::ReaderBase {
public:
class ThreadBufferMap {
public:
std::vector<framework::LoDTensor>& operator[](
const std::thread::id& thread_id) {
std::lock_guard<std::mutex> lock(mutex_);
return buffer_[thread_id];
}
void Clear() { buffer_.clear(); }
private:
std::mutex mutex_;
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
buffer_;
};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
MultiFileReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
~MultiFileReader() { EndScheduler(); }
private:
bool HasNext();
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
......@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable ThreadBufferMap thread_buffer_map_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
if (HasNext()) {
buffer_->Receive(out);
}
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
*out = thread_local_buffer;
thread_local_buffer.clear();
}
bool MultipleReader::HasNext() const {
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
: true;
}
void MultipleReader::ReInit() {
void MultiFileReader::ReInit() {
EndScheduler();
thread_buffer_map_.Clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
bool MultiFileReader::HasNext() {
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
}
return buffer_->CanReceive();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
......@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
......@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
......@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
while (true) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
break;
}
try {
buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) {
......@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
}
};
......@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
......
......@@ -33,22 +33,26 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512;
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
USE_CUDA_ATOMIC(Add, unsigned long long int);
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// It because unsigned long long int is not necessarily uint64_t
USE_CUDA_ATOMIC(Add, unsigned long long int); // NOLINT
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert(sizeof(int64_t) == sizeof(long long int),
// Here, we check long long int must be int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicAdd(reinterpret_cast<unsigned long long int*>(address),
static_cast<unsigned long long int>(val));
return CudaAtomicAdd(
reinterpret_cast<unsigned long long int*>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
unsigned long long int* address_as_ull =
reinterpret_cast<unsigned long long int*>(address);
unsigned long long int old = *address_as_ull, assumed;
unsigned long long int* address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int*>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
......
......@@ -61,7 +61,7 @@ struct NCCLContext {
ncclComm_t comm_;
explicit NCCLContext(int dev_id)
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {}
: ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
cudaStream_t stream() const { return ctx_->stream(); }
......@@ -95,6 +95,7 @@ struct NCCLContextMap {
std::vector<int> order_;
explicit NCCLContextMap(const std::vector<platform::Place> &places) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size());
for (auto &p : places) {
int dev_id = boost::get<CUDAPlace>(p).device;
......@@ -105,15 +106,17 @@ struct NCCLContextMap {
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
std::vector<ncclComm_t> comms;
comms.resize(order_.size());
if (places.size() > 1) {
std::vector<ncclComm_t> comms;
comms.resize(order_.size());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0]));
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(order_.size()), &order_[0]));
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
int i = 0;
for (auto &dev_id : order_) {
contexts_.at(dev_id).comm_ = comms[i++];
}
}
}
......
......@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "")
......@@ -554,6 +553,7 @@ All parameter, weight, gradient are variables in Paddle.
bcast_vars, main_program, loss_var_name,
scope, local_scopes, allow_op_delay);
})
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
.def("local_scopes",
[](ParallelExecutor &self) -> std::vector<Scope *> * {
return &self.GetLocalScopes();
......
......@@ -39,7 +39,7 @@ class RecordIOWriter {
void CompleteAppendTensor() {
auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
framework::WriteToRecordIO(writer_, tensors_, ctx);
framework::WriteToRecordIO(&writer_, tensors_, ctx);
tensors_.clear();
}
......
......@@ -29,6 +29,7 @@ import optimizer
import backward
import regularizer
import average
import metrics
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import warnings
"""
Class of all kinds of Average.
......@@ -22,6 +23,8 @@ import numpy as np
wrappers of Python functions.
"""
__all__ = ["WeightedAverage"]
def _is_number_(var):
return isinstance(var, int) or isinstance(var, float) or (isinstance(
......@@ -34,6 +37,9 @@ def _is_number_or_matrix_(var):
class WeightedAverage(object):
def __init__(self):
warnings.warn(
"The %s is deprecated, please use fluid.metrics.Accuracy instead." %
(self.__class__.__name__), Warning)
self.reset()
def reset(self):
......
......@@ -255,6 +255,7 @@ class DistributeTranspiler:
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import numpy as np
import layers
......@@ -59,6 +60,9 @@ class Evaluator(object):
"""
def __init__(self, name, **kwargs):
warnings.warn(
"The %s is deprecated, because maintain a modified program inside evaluator cause bug easily, please use fluid.metrics.%s instead."
% (self.__class__.__name__, self.__class__.__name__), Warning)
self.states = []
self.metrics = []
self.helper = LayerHelper(name, **kwargs)
......
......@@ -18,7 +18,8 @@ import contextlib
__all__ = [
'Constant', 'Uniform', 'Normal', 'Xavier', 'force_init_on_cpu',
'init_on_cpu'
'init_on_cpu', 'ConstantInitializer', 'UniformInitializer',
'NormalInitializer', 'XavierInitializer'
]
_force_init_on_cpu_ = False
......
......@@ -21,8 +21,7 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader', 'create_multi_pass_reader'
'open_files', 'read_file', 'shuffle', 'double_buffer'
]
......@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var = scope.find_var(reader.name)
return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset():
return __get_reader__().reset()
reader.eof = eof
reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
......@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes):
def open_recordio_file(filename,
shapes,
lod_levels,
dtypes,
pass_num=1,
for_parallel=False):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num,
buffer_size=None,
pass_num=1,
for_parallel=False):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixs to indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('multiple_reader')
multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
'thread_num': thread_num,
'buffer_size': buffer_size
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var)
startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True
main_prog_reader = _copy_reader_var_(default_main_program().current_block(),
startup_reader)
if pass_num > 1:
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs):
def __create_shared_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
def __create_unshared_decorated_reader__(op_type, reader, attrs):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size):
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None):
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
return __create_unshared_decorated_reader__('create_double_buffer_reader',
reader, attrs)
def multi_pass(reader, pass_num):
return __create_shared_decorated_reader__(
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def create_multi_pass_reader(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(file_obj):
......
......@@ -15,12 +15,13 @@
All layers just related to metric.
"""
import warnings
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
__all__ = ['accuracy']
__all__ = ['accuracy', 'auc']
def accuracy(input, label, k=1, correct=None, total=None):
......@@ -55,3 +56,37 @@ def accuracy(input, label, k=1, correct=None, total=None):
"Total": [total],
})
return acc_out
def auc(input, label, curve='ROC', num_thresholds=200):
warnings.warn(
"This interface not recommended, fluid.layers.auc compute the auc at every minibatch, \
but can not aggregate them and get the pass AUC, because pass \
auc can not be averaged with weighted from the minibatch auc value. \
Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
which can get every minibatch and every pass auc value.", Warning)
helper = LayerHelper("auc", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
auc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
},
attrs={"curve": curve,
"num_thresholds": num_thresholds},
outputs={"AUC": [auc_out], })
return auc_out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fluid Metrics
The metrics are accomplished via Python natively.
"""
import numpy as np
import copy
import warnings
__all__ = [
'MetricBase',
'CompositeMetric',
'Accuracy',
'ChunkEvaluator',
'EditDistance',
'DetectionMAP',
'Auc',
]
def _is_numpy_(var):
return isinstance(var, (np.ndarray, np.generic))
def _is_number_(var):
return isinstance(var, int) or isinstance(var, float) or (isinstance(
var, np.ndarray) and var.shape == (1, ))
def _is_number_or_matrix_(var):
return _is_number_(var) or isinstance(var, np.ndarray)
class MetricBase(object):
"""
Base Class for all evaluators
Args:
name(str): The name of evaluator. such as, "accuracy". Used for generate
temporary variable name.
Interface:
Note(*) : the states is the attributes who not has _ prefix.
get_config(): print current states and configuration
reset(): clear the states. If the Metrics states type is not (int, float, np.ndarray),
Please override this method.
update(): update states at every minibatch
eval(): get metric evaluation in numpy type.
"""
def __init__(self, name, **kwargs):
self._name = str(name) if name != None else self.__class__.__name__
self._kwargs = kwargs if kwargs != None else dict()
self.reset()
def __str__(self):
return self._name
def reset(self):
"""
states is the attributes who not has _ prefix.
reset the states of metrics.
"""
states = {
attr: value
for attr, value in self.__dict__.iteritems()
if not attr.startswith("_")
}
for attr, value in states.iteritems():
if isinstance(value, int):
setattr(self, attr, 0)
elif isinstance(value, float):
setattr(self, attr, .0)
elif isinstance(value, (np.ndarray, np.generic)):
setattr(self, attr, np.zeros_like(value))
else:
setattr(self, attr, None)
def get_config(self):
states = {
attr: value
for attr, value in self.__dict__.iteritems()
if not attr.startswith("_")
}
config = copy.deepcopy(self._kwargs)
config.update({"name": self._name, "states": copy.deepcopy(states)})
return config
def update(self):
raise NotImplementedError()
def eval(self):
raise NotImplementedError()
class CompositeMetric(MetricBase):
"""
Compute multiple metrics in each minibatch.
for example, merge F1, accuracy, recall into one Metric.
"""
def __init__(self, name=None, **kwargs):
super(CompositeMetric, self).__init__(name, kwargs)
self._metrics = []
def add_metric(self, metric):
if not isinstance(metric, MetricBase):
raise ValueError("SubMetric should be inherit from MetricBase.")
self._metrics.append(metric)
def eval(self):
ans = []
for m in self._metrics:
ans.append(m.eval())
return ans
class Accuracy(MetricBase):
"""
Accumulate the accuracy from minibatches and compute the average accuracy
for every pass.
Args:
name: the metrics name
Example:
minibatch_accuracy = fluid.layers.accuracy(pred, label)
accuracy_evaluator = fluid.metrics.Accuracy()
for epoch in PASS_NUM:
accuracy_evaluator.reset()
for data in batches:
loss = exe.run(fetch_list=[cost, minibatch_accuracy])
accuracy_evaluator.update(value=minibatch_accuracy, weight=batches)
accuracy = accuracy_evaluator.eval()
"""
def __init__(self, name=None):
super(Accuracy, self).__init__(name)
self.value = .0
self.weight = .0
def update(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
self.value += value * weight
self.weight += weight
def eval(self):
if self.weight == 0:
raise ValueError(
"There is no data in Accuracy Metrics. Please check layers.accuracy output has added to Accuracy."
)
return self.value / self.weight
class ChunkEvalutor(MetricBase):
"""
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter
numbers.
"""
def __init__(self, name=None):
super(ChunkEvalutor, self).__init__(name)
self.num_infer_chunks = 0
self.num_label_chunks = 0
self.num_correct_chunks = 0
def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks):
if not _is_number_or_matrix_(num_infer_chunks):
raise ValueError(
"The 'num_infer_chunks' must be a number(int, float) or a numpy ndarray."
)
if not _is_number_or_matrix_(num_label_chunks):
raise ValueError(
"The 'num_label_chunks' must be a number(int, float) or a numpy ndarray."
)
if not _is_number_or_matrix_(num_correct_chunks):
raise ValueError(
"The 'num_correct_chunks' must be a number(int, float) or a numpy ndarray."
)
self.num_infer_chunks += num_infer_chunks
self.num_label_chunks += num_label_chunks
self.num_correct_chunks += num_correct_chunks
def eval(self):
precision = float(
self.num_correct_chunks
) / self.num_infer_chunks if self.num_infer_chunks else 0
recall = float(self.num_correct_chunks
) / self.num_label_chunks if self.num_label_chunks else 0
f1_score = float(2 * precision * recall) / (
precision + recall) if self.num_correct_chunks else 0
return precision, recall, f1_score
class EditDistance(MetricBase):
"""
Accumulate edit distance sum and sequence number from mini-batches and
compute the average edit_distance and instance error of all batches.
Args:
name: the metrics name
Example:
edit_distance_metrics = fluid.layers.edit_distance(input, label)
distance_evaluator = fluid.metrics.EditDistance()
for epoch in PASS_NUM:
distance_evaluator.reset()
for data in batches:
loss = exe.run(fetch_list=[cost] + list(edit_distance_metrics))
distance_evaluator.update(*edit_distance_metrics)
distance, instance_error = distance_evaluator.eval()
In the above example:
'distance' is the average of the edit distance in a pass.
'instance_error' is the instance error rate in a pass.
"""
def __init__(self, name):
super(EditDistance, self).__init__(name)
self.total_distance = .0
self.seq_num = 0
self.instance_error = 0
def update(self, distances, seq_num):
if not _is_numpy_(distances):
raise ValueError("The 'distances' must be a numpy ndarray.")
if not _is_number_(seq_num):
raise ValueError("The 'seq_num' must be a number(int, float).")
seq_right_count = np.sum(distances == 0)
total_distance = np.sum(distances)
self.seq_num += seq_num
self.instance_error += seq_num - seq_right_count
self.total_distance += total_distance
def eval():
if self.seq_num == 0:
raise ValueError(
"There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance."
)
avg_distance = self.total_distance / self.seq_num
avg_instance_error = self.instance_error / self.seq_num
return avg_distance, avg_instance_error
class DetectionMAP(MetricBase):
"""
Calculate the detection mean average precision (mAP).
TODO (Dang Qingqing): update the following doc.
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
"""
def __init__(self, name=None):
super(DetectionMAP, self).__init__(name)
# the current map value
self.value = .0
def update(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
self.value += value
self.weight += weight
def eval(self):
if self.weight == 0:
raise ValueError(
"There is no data in DetectionMAP Metrics. "
"Please check layers.detection_map output has added to DetectionMAP."
)
return self.value / self.weight
class Auc(MetricBase):
"""
Auc Metrics which adapts to binary classification.
Need to note that auc metrics compute the value via Python natively.
If you concern the speed, please use the fluid.layers.auc instead.
The `auc` function creates four local variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` that are used to
compute the AUC. To discretize the AUC curve, a linearly spaced set of
thresholds is used to compute pairs of recall and precision values. The area
under the ROC-curve is therefore computed using the height of the recall
values by the false positive rate, while the area under the PR-curve is the
computed using the height of the precision values by the recall.
Args:
name: metric name
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
num_thresholds: The number of thresholds to use when discretizing the roc
curve.
"NOTE: only implement the ROC curve type via Python now."
"""
def __init__(self, name, curve='ROC', num_thresholds=200):
super(MetricBase, self).__init__(name, curve, num_thresholds)
self._curve = curve
self._num_thresholds = num_thresholds
self._epsilon = 1e-6
self.tp_list = np.ndarray((num_thresholds, ))
self.fn_list = np.ndarray((num_thresholds, ))
self.tn_list = np.ndarray((num_thresholds, ))
self.fp_list = np.ndarray((num_thresholds, ))
def update(self, labels, predictions, axis=1):
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
if not _is_numpy_(predictions):
raise ValueError("The 'predictions' must be a numpy ndarray.")
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# caculate TP, FN, TN, FP count
for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if predictions[i, 0] >= thresh:
tp += 1
else:
fn += 1
else:
if predictions[i, 0] >= thresh:
fp += 1
else:
tn += 1
tp_list[idx_thresh] += tp
fn_list[idx_thresh] += fn
tn_list[idx_thresh] += tn
fp_list[idx_thresh] += fp
def eval(self):
epsilon = self._epsilon
num_thresholds = self._num_thresholds
tpr = (tp_list.astype("float32") + epsilon) / (
tp_list + fn_list + epsilon)
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
rec = (tp_list.astype("float32") + epsilon) / (
tp_list + fp_list + epsilon)
x = fpr[:num_thresholds - 1] - fpr[1:]
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
auc_value = np.sum(x * y)
return auc_value
......@@ -100,9 +100,11 @@ class ParallelExecutor(object):
local_scopes = share_vars_from.executor.local_scopes(
) if share_vars_from else []
persistable_vars = [
self.persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, main.list_vars())
for v in filter(lambda var: \
var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars())
]
self.executor = core.ParallelExecutor(
......@@ -113,7 +115,7 @@ class ParallelExecutor(object):
p.name for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
set(persistable_vars),
set(self.persistable_vars),
main.desc,
loss_name if loss_name else '',
scope,
......@@ -143,3 +145,6 @@ class ParallelExecutor(object):
self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
return [arr[i] for i in range(len(arr))]
def bcast_params(self):
self.executor.bcast_params(set(self.persistable_vars))
......@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
while True:
try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
......
......@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader(
data_file = fluid.layers.io.multi_pass(
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)
......@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_file.eof():
img_val, = exe.run(fetch_list=[img])
while True:
try:
img_val, = exe.run(fetch_list=[img])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
......
......@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
reader = fluid.layers.open_files(
filenames=['./mnist.recordio'],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
for _ in xrange(4):
......@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
reader = fluid.layers.open_files(
filenames=['mnist.recordio'],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
......
......@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
# train a pass
batch_id = 0
while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss])
while True:
try:
tmp, = exe.run(fetch_list=[avg_loss])
except fluid.core.EnforceNotMet as ex:
self.assertIn("There is no next data.", ex.message)
break
avg_loss_np.append(tmp)
batch_id += 1
data_file.reset()
......@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))
self.test_main(decorator_callback=lambda reader: fluid.layers.io.shuffle(reader, buffer_size=200))
def test_double_buffer_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader,
self.test_main(decorator_callback=lambda reader: fluid.layers.io.double_buffer(reader,
place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册