提交 b9c28df9 编写于 作者: T typhoonzero

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into multigpumultinode

......@@ -36,7 +36,8 @@ MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR})
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include mkldnn.h
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
......
# Problem
# Kernel Hint Design
## Problem
In PaddlePaddle's [Design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this.
In the current design, we use KernelType to describe one kernel.
......
# Background
# Kernel Selection
## Background
Every operator has many kernels because there are multiple data types, places, data layout, library type that Fluid supports. We use the `OpKernelType ` to describe kernel types that operators can hold.
The `OpKernelType ` is as follows:
......
Install and Build
=================
install and Compile
==========
.. _install_steps:
Install Steps
++++++++
PaddlePaddle provides various methods of installation for many different users
You can choose either pip or Docker to complete your install:
Focus on Deep Learning Model Development
-----------------
PaddlePaddle provides lots of packages of python wheel , that pip can install:
.. toctree::
:maxdepth: 1
:maxdepth: 1
pip_install_en.rst
docker_install_en.rst
pip_install_en.rst
Build from Source
-----------------
This is the most convenient way of installation. Please choose the right installation package with machine configure and system.
Follow the Bottom Frame
----------
PaddlePaddle also supports installation using Docker. Please refer to the tutorial below:
.. toctree::
:maxdepth: 1
docker_install_en.rst
.. warning::
We recommend running PaddlePaddle in Docker. This method has the following advantages:
We recommend to directly install via above installation steps, you'll only need to build PaddlePaddle from source when you need a modifed binary.
- Does not require installation of third-party dependencies.
- Easy to share runtime environment.
.. toctree::
Lastly, users can also compile and install PaddlePaddle from source code. The instructions are below:
.. toctree::
:maxdepth: 1
build_from_source_en.md
build_from_source_en.rst
.. warning::
One caveat with this approach is that developers will have to download, compile and install all third-party dependencies. Thus this process of installation is more time consuming.
FAQ
++++++++++
-----------
For any problems during installation, please refer to the page below for answers:
:ref:`常见问题解答 <install_faq>`
If the problem still persists, you are welcome to seek assistance from the PaddlePaddle community:
`FAQ <http://www.paddlepaddle.org/docs/develop/documentation/zh/faq/build_and_install/index_en.html>`_
`创建issue <https://github.com/PaddlePaddle/Paddle/issues/new>`_
......@@ -105,7 +105,7 @@ static void BuildVar(const std::string& param_name,
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
InitDevices();
InitDevices(true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
......
......@@ -80,7 +80,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto graph = new SSAGraph();
SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
result.vars_.resize(places_.size());
// We cannot invoke resize. It is a bug of GCC 4.8
result.vars_ = std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
......@@ -180,15 +184,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get());
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
vars.emplace_back(new VarHandle);
auto &var = vars.back();
var->place_ = p;
var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(&var);
op_handle->AddOutput(var.get());
}
#else
PADDLE_ENFORCE("Not implemented");
......
......@@ -16,6 +16,8 @@
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -24,7 +26,9 @@ namespace framework {
namespace details {
struct SSAGraph {
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_;
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
......
......@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *write_op = (*it_new)->generated_op_;
auto &read_ops = (*it_old)->pending_ops_;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
......@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle);
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
} else {
var = &var_holder.rbegin()->second;
var = var_holder.rbegin()->get();
}
return var;
}
......@@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
vars.emplace_back(new VarHandle());
auto &var = vars.back();
var->version_ = version;
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
}
template <typename Callback>
......@@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(pair2.second);
callback(*pair2);
}
}
}
......
......@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second);
InsertPendingVar(*version_pair);
}
}
}
......@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
......
......@@ -64,7 +64,7 @@ void InitP2P(int count) {
#endif
}
void InitDevices() {
void InitDevices(bool init_p2p) {
/*Init all avaiable devices by default */
std::vector<platform::Place> places;
......@@ -85,7 +85,9 @@ void InitDevices() {
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
InitP2P(count);
if (init_p2p) {
InitP2P(count);
}
platform::DeviceContextPool::Init(places);
}
......
......@@ -24,7 +24,7 @@ void InitGflags(std::vector<std::string> &argv);
void InitGLOG(const std::string &prog_name);
void InitDevices();
void InitDevices(bool init_p2p);
} // namespace framework
} // namespace paddle
......@@ -21,7 +21,7 @@ TEST(InitDevices, CPU) {
using paddle::platform::DeviceContextPool;
#ifndef PADDLE_WITH_CUDA
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U);
#endif
......@@ -33,7 +33,7 @@ TEST(InitDevices, CUDA) {
#ifdef PADDLE_WITH_CUDA
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
InitDevices(true);
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U + static_cast<unsigned>(count));
#endif
......
......@@ -30,7 +30,7 @@ __global__ void test(size_t* a, int size) {
}
TEST(LoD, data) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
......@@ -46,7 +46,7 @@ TEST(LoD, data) {
}
TEST(LoDTensor, LoDInGPU) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::LoDTensor lod_tensor;
paddle::platform::CUDAPlace place(0);
......
......@@ -72,7 +72,7 @@ REGISTER_OP_WITHOUT_GRADIENT(test_operator,
paddle::framework::OpWithoutKernelCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
......@@ -198,7 +198,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
......@@ -228,7 +228,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
......@@ -269,7 +269,7 @@ class OperatorClone : public paddle::framework::OperatorBase {
};
TEST(Operator, Clone) {
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
......
......@@ -85,9 +85,9 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
}
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
BlockDesc *global_block = blocks_[0].get();
auto &global_block = Block(0);
std::vector<std::string> feed_target_names;
for (auto *op : global_block->AllOps()) {
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]);
}
......@@ -96,9 +96,9 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
}
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
BlockDesc *global_block = blocks_[0].get();
auto &global_block = Block(0);
std::vector<std::string> fetch_target_names;
for (auto *op : global_block->AllOps()) {
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_target_names.push_back(op->Input("X")[0]);
}
......@@ -106,5 +106,43 @@ const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
return fetch_target_names;
}
void ProgramDesc::SetFeedHolderName(const std::string &feed_holder_name) {
auto *global_block = MutableBlock(0);
int index = 0;
for (auto *op : global_block->AllOps()) {
if (op->Type() == kFeedOpType) {
// Unify the input's name of all feed_ops to feed_holder_name
global_block->RemoveVar(op->Input("X")[0]);
op->SetInput("X", {feed_holder_name});
op->SetAttr("col", {index});
op->CheckAttrs();
index++;
}
}
auto *feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
feed_holder->SetPersistable(true);
}
void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
auto *global_block = MutableBlock(0);
int index = 0;
for (auto *op : global_block->AllOps()) {
if (op->Type() == kFetchOpType) {
// Unify the output's name of all fetch_ops to fetch_holder_name
global_block->RemoveVar(op->Output("Out")[0]);
op->SetOutput("Out", {fetch_holder_name});
op->SetAttr("col", {index});
op->CheckAttrs();
index++;
}
}
auto *fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true);
}
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -52,9 +53,26 @@ class ProgramDesc {
proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target.
// This function is used to collect the output variable's name of all
// feed_ops.
const std::vector<std::string> GetFeedTargetNames();
// The input variable of fetch_op is referenced as fetch_target.
// This function is used to collect the input variable's name of all
// fetch_ops.
const std::vector<std::string> GetFetchTargetNames();
// The input variable of feed_op that holds input Tensor provided by users is
// referenced as feed_holder.
// This function is used to change or unify the feed_holder variables' name.
void SetFeedHolderName(const std::string &feed_holder_name);
// The output variable of fetch_op that holds output Tensor needed by users is
// referenced as fetch_holder.
// This function is used to change or unify the fetch_holder variables' name.
void SetFetchHolderName(const std::string &fetch_holder_name);
private:
proto::ProgramDesc desc_;
......
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/inference/tests/test_multi_thread_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
......@@ -26,32 +27,63 @@ TEST(inference, fit_a_line) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
paddle::framework::LoDTensor input;
// The second dim of the input tensor should be 13
// The input data should be >= 0
int64_t batch_size = 10;
SetupTensor<float>(&input, {batch_size, 13}, static_cast<float>(0),
static_cast<float>(10));
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
for (int num_threads : {1, 2}) {
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_feeds;
cpu_feeds.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* input = new paddle::framework::LoDTensor();
// The second dim of the input tensor should be 13
// The input data should be >= 0
int64_t batch_size = 10;
SetupTensor<float>(input, {batch_size, 13}, static_cast<float>(0),
static_cast<float>(10));
cpu_feeds[i].push_back(input);
}
paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_fetchs1;
cpu_fetchs1.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* output = new paddle::framework::LoDTensor();
cpu_fetchs1[i].push_back(output);
}
// Run inference on CPU
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1);
LOG(INFO) << output1.dims();
// Run inference on CPU
LOG(INFO) << "--- CPU Runs (num_threads: " << num_threads << "): ---";
if (num_threads == 1) {
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds[0],
cpu_fetchs1[0]);
} else {
TestMultiThreadInference<paddle::platform::CPUPlace>(
dirname, cpu_feeds, cpu_fetchs1, num_threads);
}
#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
cpu_fetchs2.push_back(&output2);
std::vector<std::vector<paddle::framework::LoDTensor*>> cpu_fetchs2;
cpu_fetchs2.resize(num_threads);
for (int i = 0; i < num_threads; ++i) {
auto* output = new paddle::framework::LoDTensor();
cpu_fetchs2[i].push_back(output);
}
// Run inference on CUDA GPU
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2);
LOG(INFO) << output2.dims();
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs (num_threads: " << num_threads << "): ---";
if (num_threads == 1) {
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds[0],
cpu_fetchs2[0]);
} else {
TestMultiThreadInference<paddle::platform::CUDAPlace>(
dirname, cpu_feeds, cpu_fetchs2, num_threads);
}
CheckError<float>(output1, output2);
for (int i = 0; i < num_threads; ++i) {
CheckError<float>(*cpu_fetchs1[i][0], *cpu_fetchs2[i][0]);
delete cpu_fetchs2[i][0];
}
#endif
for (int i = 0; i < num_threads; ++i) {
delete cpu_feeds[i][0];
delete cpu_fetchs1[i][0];
}
} // num_threads-loop
}
......@@ -25,7 +25,8 @@ limitations under the License. */
template <typename T>
void SetupTensor(paddle::framework::LoDTensor* input,
paddle::framework::DDim dims, T lower, T upper) {
std::mt19937 rng(100); // An arbitrarily chosen but fixed seed.
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* input_ptr = input->mutable_data<T>(dims, paddle::platform::CPUPlace());
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/io.h"
void ThreadedRunInference(
const std::unique_ptr<paddle::framework::ProgramDesc>& inference_program,
paddle::framework::Executor* executor, paddle::framework::Scope* scope,
const int thread_id,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
auto copy_program = std::unique_ptr<paddle::framework::ProgramDesc>(
new paddle::framework::ProgramDesc(*inference_program));
std::string feed_holder_name = "feed_" + paddle::string::to_string(thread_id);
std::string fetch_holder_name =
"fetch_" + paddle::string::to_string(thread_id);
copy_program->SetFeedHolderName(feed_holder_name);
copy_program->SetFetchHolderName(fetch_holder_name);
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
copy_program->GetFeedTargetNames();
const std::vector<std::string>& fetch_target_names =
copy_program->GetFetchTargetNames();
// 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i];
}
// 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
}
// 6. Run the inference program
executor->Run(*copy_program, scope, feed_targets, fetch_targets, true,
feed_holder_name, fetch_holder_name);
}
template <typename Place>
void TestMultiThreadInference(
const std::string& dirname,
const std::vector<std::vector<paddle::framework::LoDTensor*>>& cpu_feeds,
const std::vector<std::vector<paddle::framework::LoDTensor*>>& cpu_fetchs,
const int num_threads) {
// 1. Define place, executor, scope
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program =
paddle::inference::Load(executor, *scope, dirname);
std::vector<std::thread*> threads;
for (int i = 0; i < num_threads; ++i) {
threads.push_back(new std::thread(
ThreadedRunInference, std::ref(inference_program), &executor, scope, i,
std::ref(cpu_feeds[i]), std::ref(cpu_fetchs[i])));
}
for (int i = 0; i < num_threads; ++i) {
threads[i]->join();
delete threads[i];
}
delete scope;
}
......@@ -662,14 +662,3 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -17,31 +14,19 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
ops::ActivationKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<double>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<paddle::platform::float16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -15,9 +12,11 @@ limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -338,11 +337,25 @@ struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};
template <>
struct Sine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sin(static_cast<float>(val)));
}
};
template <typename T>
struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};
template <>
struct Cosine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(cos(static_cast<float>(val)));
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
......@@ -826,6 +839,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
......
......@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_helper.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif
......@@ -43,35 +44,35 @@ namespace operators {
*/
inline void get_mid_dims(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis,
int& pre, int& n, int& post) {
pre = 1;
n = 1;
post = 1;
int* pre, int* n, int* post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
pre *= x_dims[i];
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
n *= y_dims[i];
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
post *= x_dims[i];
(*post) *= x_dims[i];
}
}
inline void trim_trailing_singular_dims(framework::DDim& dims) {
inline void trim_trailing_singular_dims(framework::DDim* dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
auto actual_dims_size = dims->size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
if ((*dims)[actual_dims_size - 1] != 1) break;
}
if (actual_dims_size != dims.size()) {
auto actual_dims = framework::vectorize(dims);
if (actual_dims_size != dims->size()) {
auto actual_dims = framework::vectorize(*dims);
actual_dims.resize(actual_dims_size);
dims = framework::make_ddim(actual_dims);
*dims = framework::make_ddim(actual_dims);
}
}
......@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
: super_t(x), begin_(x), n_(n){};
: super_t(x), begin_(x), n_(n) {}
friend class thrust::iterator_core_access;
private:
......@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post){};
: super_t(x), begin_(x), n_(n), post_(post) {}
friend class thrust::iterator_core_access;
private:
......@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
}
#ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__ T shm[32];
const int warpSize = 32;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset);
}
return val;
}
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w,
......@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
val = reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
val = reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto y_dim = y.dims();
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
trim_trailing_singular_dims(y_dim);
trim_trailing_singular_dims(&y_dim);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, pre, n, post);
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
......@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
}
}
}
};
}
template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor>
......@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
trim_trailing_singular_dims(y_dims);
trim_trailing_singular_dims(&y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
broadcastfunctor f;
......@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
trim_trailing_singular_dims(y_dims);
trim_trailing_singular_dims(&y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
functor.RunRowWise(n, pre);
return;
......
......@@ -62,53 +62,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__ T shm[32];
const int warpSize = 32;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(mask, val, offset);
}
return val;
}
} // namespace platform
} // namespace paddle
......@@ -8,10 +8,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace platform {
......
......@@ -8,11 +8,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h"
......
......@@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
TEST(Device, Init) {
using paddle::platform::DeviceContext;
......
......@@ -11,15 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_tracer.h"
#include <google/protobuf/text_format.h>
#include <deque>
#include <fstream>
#include <map>
#include <mutex>
#include <mutex> // NOLINT
#include <numeric>
#include <thread>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/string/printf.h"
......@@ -123,7 +127,7 @@ void DisableActivity() {
void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
size_t *maxNumRecords) {
uint8_t *buf = (uint8_t *)malloc(kBufSize + kAlignSize);
uint8_t *buf = reinterpret_cast<uint8_t *>(malloc(kBufSize + kAlignSize));
*size = kBufSize;
*buffer = ALIGN_BUFFER(buf, kAlignSize);
*maxNumRecords = 0;
......
......@@ -11,8 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/profiler.pb.h"
......
......@@ -1003,6 +1003,46 @@ HOSTDEVICE inline float16 exp(const float16& a) {
return float16(::expf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 log(const float16& a) {
return float16(::logf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
return float16(::tanhf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
return float16(::sqrtf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
return float16(::ceilf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 floor(const float16& a) {
return float16(::floorf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 round(const float16& a) {
return float16(::roundf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
return float16(::fabs(static_cast<float>(a)));
}
} // namespace numext
} // namespace Eigen
......@@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkldnn.hpp>
#include <vector>
#include "mkldnn/include/mkldnn.hpp"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
......
......@@ -15,8 +15,11 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <iomanip>
#include <map>
#include <mutex> // NOLINT
#include <string>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
......@@ -28,10 +31,10 @@ limitations under the License. */
namespace paddle {
namespace platform {
struct EventList;
// The profiler state, the initial value is ProfilerState::kDisabled
static ProfilerState g_state = ProfilerState::kDisabled;
// To record which timer the profiler used, CUDA or CPU.
static std::string g_profiler_place = "";
// The thread local event list only can be accessed by the specific thread
// The thread index of each thread
static thread_local int32_t g_thread_id;
......@@ -45,6 +48,39 @@ static std::list<std::shared_ptr<EventList>> g_all_event_lists;
// The thread local event list only can be accessed by the specific thread
static thread_local std::shared_ptr<EventList> g_event_list;
struct EventList {
constexpr static size_t kMB = 1024 * 1024;
constexpr static size_t kEventBlockSize = 16 * kMB;
constexpr static size_t kEventSize = sizeof(Event);
constexpr static size_t kEventAlign = alignof(Event);
constexpr static size_t kNumBlock =
kEventBlockSize /
((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign);
template <typename... Args>
void Record(Args&&... args) {
if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) {
event_blocks.emplace_front();
event_blocks.front().reserve(kNumBlock);
}
event_blocks.front().emplace_back(std::forward<Args>(args)...);
}
std::vector<Event> Reduce() {
std::vector<Event> result;
for (auto& block : event_blocks) {
result.insert(result.begin(), std::make_move_iterator(block.begin()),
std::make_move_iterator(block.end()));
}
event_blocks.clear();
return result;
}
void Clear() { event_blocks.clear(); }
std::forward_list<std::vector<Event>> event_blocks;
};
inline uint64_t GetTimeInNsec() {
using clock = std::conditional<std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock,
......@@ -60,9 +96,9 @@ inline uint64_t PosixInNsec() {
return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec);
}
Event::Event(EventKind kind, std::string name, uint32_t thread_id,
Event::Event(EventType type, std::string name, uint32_t thread_id,
const DeviceContext* dev_ctx)
: kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) {
: type_(type), name_(name), thread_id_(thread_id), has_cuda_(false) {
#ifdef PADDLE_WITH_CUDA
has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false;
if (has_cuda_) {
......@@ -76,17 +112,7 @@ Event::Event(EventKind kind, std::string name, uint32_t thread_id,
cpu_ns_ = GetTimeInNsec();
}
std::string Event::kind() const {
switch (kind_) {
case EventKind::kMark:
return "mark";
case EventKind::kPushRange:
return "push";
case EventKind::kPopRange:
return "pop";
}
PADDLE_THROW("Unknown EventKind.");
}
const EventType& Event::type() const { return type_; }
double Event::CpuElapsedMs(const Event& e) const {
return (e.cpu_ns_ - cpu_ns_) / (1000000.0);
......@@ -129,15 +155,15 @@ inline EventList& GetEventList() {
}
void Mark(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kMark, name, g_thread_id, dev_ctx);
GetEventList().Record(EventType::kMark, name, g_thread_id, dev_ctx);
}
void PushEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPushRange, name, g_thread_id, dev_ctx);
GetEventList().Record(EventType::kPushRange, name, g_thread_id, dev_ctx);
}
void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx);
GetEventList().Record(EventType::kPopRange, name, g_thread_id, dev_ctx);
}
RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
......@@ -197,12 +223,7 @@ void EnableProfiler(ProfilerState state) {
"The profiling state should be disabled when calling ",
"EnableProfiler.");
g_state = state;
if (g_state == ProfilerState::kCUDA) {
g_profiler_place = "CUDA";
} else if (g_state == ProfilerState::kCPU) {
g_profiler_place = "CPU";
} else {
g_profiler_place = "All";
if (g_state == ProfilerState::kAll) {
GetDeviceTracer()->Enable();
}
#ifdef PADDLE_WITH_CUDA
......@@ -240,27 +261,63 @@ std::vector<std::vector<Event>> GetAllEvents() {
return result;
}
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path) {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);
g_state = ProfilerState::kDisabled;
// The information of each event given in the profiling report
struct EventItem {
std::string name;
int calls;
double total_time;
double min_time;
double max_time;
double ave_time;
};
// Print results
void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
const std::string& sorted_domain, const size_t name_width,
const size_t data_width) {
// Output header information
std::cout << "\n------------------------->"
<< " Profiling Report "
<< "<-------------------------\n\n";
std::string place;
if (g_state == ProfilerState::kCPU) {
place = "CPU";
} else if (g_state == ProfilerState::kCUDA) {
place = "CUDA";
} else if (g_state == ProfilerState::kAll) {
place = "All";
} else {
PADDLE_THROW("Invalid profiler state");
}
std::vector<std::vector<Event>> all_events = GetAllEvents();
ParseEvents(all_events, sorted_key);
ResetProfiler();
DeviceTracer* tracer = GetDeviceTracer();
if (g_profiler_place == "All" && tracer && tracer->IsEnabled()) {
tracer->Disable();
tracer->GenProfile(profile_path);
std::cout << "Place: " << place << std::endl;
std::cout << "Time unit: ms" << std::endl;
std::cout << "Sorted by " << sorted_domain
<< " in descending order in the same thread\n\n";
// Output events table
std::cout.setf(std::ios::left);
std::cout << std::setw(name_width) << "Event" << std::setw(data_width)
<< "Calls" << std::setw(data_width) << "Total"
<< std::setw(data_width) << "Min." << std::setw(data_width)
<< "Max." << std::setw(data_width) << "Ave." << std::endl;
for (size_t i = 0; i < events_table.size(); ++i) {
for (size_t j = 0; j < events_table[i].size(); ++j) {
const EventItem& event_item = events_table[i][j];
std::cout << std::setw(name_width) << event_item.name
<< std::setw(data_width) << event_item.calls
<< std::setw(data_width) << event_item.total_time
<< std::setw(data_width) << event_item.min_time
<< std::setw(data_width) << event_item.max_time
<< std::setw(data_width) << event_item.ave_time << std::endl;
}
}
std::cout << std::endl;
}
void ParseEvents(std::vector<std::vector<Event>>& events,
EventSortingKey sorted_by) {
if (g_profiler_place == "") return;
// Parse the event list and output the profiling report
void ParseEvents(const std::vector<std::vector<Event>>& events,
EventSortingKey sorted_by = EventSortingKey::kDefault) {
if (g_state == ProfilerState::kDisabled) return;
std::string sorted_domain;
std::function<bool(const EventItem&, const EventItem&)> sorted_func;
......@@ -307,9 +364,9 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
std::unordered_map<std::string, int> event_idx;
for (size_t j = 0; j < events[i].size(); j++) {
if (events[i][j].kind() == "push") {
if (events[i][j].type() == EventType::kPushRange) {
pushed_events.push_back(events[i][j]);
} else if (events[i][j].kind() == "pop") {
} else if (events[i][j].type() == EventType::kPopRange) {
std::list<Event>::reverse_iterator rit = pushed_events.rbegin();
while (rit != pushed_events.rend() &&
rit->name() != events[i][j].name()) {
......@@ -317,10 +374,10 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
}
if (rit != pushed_events.rend()) {
double event_time =
(g_profiler_place == "CUDA" || g_profiler_place == "All")
? rit->CudaElapsedMs(events[i][j])
: rit->CpuElapsedMs(events[i][j]);
double event_time = (g_state == ProfilerState::kCUDA ||
g_state == ProfilerState::kAll)
? rit->CudaElapsedMs(events[i][j])
: rit->CpuElapsedMs(events[i][j]);
std::string event_name =
"thread" + std::to_string(rit->thread_id()) + "::" + rit->name();
......@@ -376,35 +433,22 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12);
}
void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width) {
// Output header information
std::cout << "\n------------------------->"
<< " Profiling Report "
<< "<-------------------------\n\n";
std::cout << "Place: " << g_profiler_place << std::endl;
std::cout << "Time unit: ms" << std::endl;
std::cout << "Sorted by " << sorted_domain
<< " in descending order in the same thread\n\n";
// Output events table
std::cout.setf(std::ios::left);
std::cout << std::setw(name_width) << "Event" << std::setw(data_width)
<< "Calls" << std::setw(data_width) << "Total"
<< std::setw(data_width) << "Min." << std::setw(data_width)
<< "Max." << std::setw(data_width) << "Ave." << std::endl;
for (size_t i = 0; i < events_table.size(); ++i) {
for (size_t j = 0; j < events_table[i].size(); ++j) {
EventItem& event_item = events_table[i][j];
std::cout << std::setw(name_width) << event_item.name
<< std::setw(data_width) << event_item.calls
<< std::setw(data_width) << event_item.total_time
<< std::setw(data_width) << event_item.min_time
<< std::setw(data_width) << event_item.max_time
<< std::setw(data_width) << event_item.ave_time << std::endl;
}
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path) {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);
std::vector<std::vector<Event>> all_events = GetAllEvents();
ParseEvents(all_events, sorted_key);
ResetProfiler();
DeviceTracer* tracer = GetDeviceTracer();
if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) {
tracer->Disable();
tracer->GenProfile(profile_path);
}
std::cout << std::endl;
g_state = ProfilerState::kDisabled;
}
} // namespace platform
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <forward_list>
#include <list>
#include <mutex>
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.pb.h"
......@@ -23,16 +23,16 @@ limitations under the License. */
namespace paddle {
namespace platform {
enum EventKind { kMark, kPushRange, kPopRange };
enum EventType { kMark, kPushRange, kPopRange };
class Event {
public:
// The DeviceContext is used to get the cuda stream.
// If CPU profiling mode, can pass nullptr.
Event(EventKind kind, std::string name, uint32_t thread_id,
Event(EventType type, std::string name, uint32_t thread_id,
const DeviceContext* dev_ctx);
std::string kind() const;
const EventType& type() const;
std::string name() const { return name_; }
uint32_t thread_id() const { return thread_id_; }
bool has_cuda() const { return has_cuda_; }
......@@ -46,7 +46,7 @@ class Event {
double CudaElapsedMs(const Event& e) const;
private:
EventKind kind_;
EventType type_;
std::string name_;
uint32_t thread_id_;
int64_t cpu_ns_;
......@@ -57,39 +57,6 @@ class Event {
#endif
};
struct EventList {
constexpr static size_t kMB = 1024 * 1024;
constexpr static size_t kEventBlockSize = 16 * kMB;
constexpr static size_t kEventSize = sizeof(Event);
constexpr static size_t kEventAlign = alignof(Event);
constexpr static size_t kNumBlock =
kEventBlockSize /
((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign);
template <typename... Args>
void Record(Args&&... args) {
if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) {
event_blocks.emplace_front();
event_blocks.front().reserve(kNumBlock);
}
event_blocks.front().emplace_back(std::forward<Args>(args)...);
}
std::vector<Event> Reduce() {
std::vector<Event> result;
for (auto& block : event_blocks) {
result.insert(result.begin(), std::make_move_iterator(block.begin()),
std::make_move_iterator(block.end()));
}
event_blocks.clear();
return result;
}
void Clear() { event_blocks.clear(); }
std::forward_list<std::vector<Event>> event_blocks;
};
enum ProfilerState {
kDisabled, // disabled state
kCPU, // CPU profiling state
......@@ -136,16 +103,6 @@ struct RecordThread {
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
std::vector<std::vector<Event>> GetAllEvents();
// The information of each event given in the profiling report
struct EventItem {
std::string name;
int calls;
double total_time;
double min_time;
double max_time;
double ave_time;
};
// Candidate keys to sort the profiling report
enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve };
......@@ -158,14 +115,5 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path);
// Parse the event list and output the profiling report
void ParseEvents(std::vector<std::vector<Event>>&,
EventSortingKey sorted_by = EventSortingKey::kDefault);
// Print results
void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width);
} // namespace platform
} // namespace paddle
......@@ -13,22 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include <string>
#ifdef PADDLE_WITH_CUDA
#include "cuda_runtime.h"
#include <cuda_runtime.h>
#endif
#include "gtest/gtest.h"
TEST(Event, CpuElapsedTime) {
using paddle::platform::Event;
using paddle::platform::EventKind;
using paddle::platform::EventType;
Event start_event(EventKind::kPushRange, "test", 0, nullptr);
Event start_event(EventType::kPushRange, "test", 0, nullptr);
EXPECT_TRUE(start_event.has_cuda() == false);
int counter = 0;
while (counter != 1000) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0, nullptr);
Event stop_event(EventType::kPopRange, "test", 0, nullptr);
EXPECT_GT(start_event.CpuElapsedMs(stop_event), 0);
}
......@@ -38,16 +39,16 @@ TEST(Event, CudaElapsedTime) {
using paddle::platform::CUDADeviceContext;
using paddle::platform::CUDAPlace;
using paddle::platform::Event;
using paddle::platform::EventKind;
using paddle::platform::EventType;
DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(0));
Event start_event(EventKind::kPushRange, "test", 0, dev_ctx);
Event start_event(EventType::kPushRange, "test", 0, dev_ctx);
EXPECT_TRUE(start_event.has_cuda() == true);
int counter = 0;
while (counter != 1000) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0, dev_ctx);
Event stop_event(EventType::kPopRange, "test", 0, dev_ctx);
EXPECT_GT(start_event.CudaElapsedMs(stop_event), 0);
}
#endif
......@@ -55,7 +56,7 @@ TEST(Event, CudaElapsedTime) {
TEST(RecordEvent, RecordEvent) {
using paddle::platform::DeviceContext;
using paddle::platform::Event;
using paddle::platform::EventKind;
using paddle::platform::EventType;
using paddle::platform::RecordEvent;
using paddle::platform::ProfilerState;
using paddle::platform::EventSortingKey;
......
......@@ -465,7 +465,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_devices", &framework::InitDevices);
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
#ifdef PADDLE_WITH_CUDA
......
......@@ -41,6 +41,6 @@ int main(int argc, char** argv) {
paddle::memory::Used(paddle::platform::CUDAPlace(0));
#endif
paddle::framework::InitDevices();
paddle::framework::InitDevices(true);
return RUN_ALL_TESTS();
}
......@@ -85,6 +85,8 @@ def __bootstrap__():
import core
import os
in_test = 'unittest' in sys.modules
try:
num_threads = int(os.getenv('OMP_NUM_THREADS', '1'))
except ValueError:
......@@ -109,8 +111,11 @@ def __bootstrap__():
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
core.init_devices()
# don't init_p2p when in unittest to save time.
core.init_devices(not in_test)
# TODO(panyx0718): Avoid doing complex initialization logic in __init__.py.
# Consider paddle.init(args) or paddle.main(args)
layers.monkey_patch_variable()
__bootstrap__()
......@@ -278,11 +278,21 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
else:
orig_var_name = v.name
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = \
pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.trainers > 1:
for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var(
......@@ -293,12 +303,6 @@ class DistributeTranspiler:
shape=v.shape)
recv_inputs.append(var)
else:
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(single_trainer_var)
# step3
......
......@@ -102,7 +102,7 @@ if '${WITH_FLUID_ONLY}'== 'OFF':
package_data['py_paddle']=['*.py','_swig_paddle.so']
package_dir={
'': '${CMAKE_CURRENT_SOURCE_DIR}',
'': '${PADDLE_BINARY_DIR}/python',
# The paddle.fluid.proto will be generated while compiling.
# So that package points to other directory.
'paddle.fluid.proto.profiler': '${PADDLE_BINARY_DIR}/paddle/fluid/platform',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册