Commit 3dd01823 authored by Yu Yang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/clean_matmul

@@ -80,6 +80,8 @@ parser.add_argument(
     type=str,
     default="",
     help="Comma-separated list of hostname:port pairs")
+parser.add_argument(
+    "--profile", action='store_true', help="If set, profile a few steps.")
 # Flags for defining the tf.train.Server
 parser.add_argument(
@@ -183,8 +185,8 @@ def main():
         start_time = time.time()
         num_samples = 0
         train_pass_acc.reset()
-        for batch_id, data in enumerate(train_reader()):
-            ts = time.time()
+
+        def run_step(batch_id, data):
             img_data = np.array(
                 map(lambda x: x[0].reshape(data_shape), data)).astype(
                     "float32")
@@ -196,14 +198,28 @@ def main():
                 feed={"pixel": img_data,
                       "label": y_data},
                 fetch_list=[avg_cost, batch_acc, batch_size])
+            return loss, acc, b_size
+
+        if args.profile and args.task_index == 0:
+            # warmup.
+            for batch_id, data in enumerate(train_reader()):
+                if batch_id > 5: break
+                run_step(batch_id, data)
+            with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
+                for batch_id, data in enumerate(train_reader()):
+                    if batch_id > 5: break
+                    run_step(batch_id, data)
+
+        for batch_id, data in enumerate(train_reader()):
+            ts = time.time()
+            loss, acc, b_size = run_step(batch_id, data)
             iters += 1
             num_samples += len(data)
             train_pass_acc.add(value=acc, weight=b_size)
             print(
-                "Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
-                "Speed = %.2f img/s " % (args.task_index, pass_id, iters,
-                                         loss, acc,
-                                         len(data) / (time.time() - ts))
+                "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
+                "Speed = %.2f img/s" % (pass_id, iters, loss, acc,
+                                        len(data) / (time.time() - ts))
             ) # The accuracy is the accumulation of batches, but not the current batch.
         pass_elapsed = time.time() - start_time
......
@@ -42,7 +42,3 @@ Codistillation is a technique that tries to scale the training further. A few tr
 [3] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine translation system: Bridging the gap between human and machine translation.
 [4] LARGE SCALE DISTRIBUTED NEURAL NETWORK TRAINING THROUGH ONLINE DISTILLATION
@@ -143,7 +143,7 @@ OpDesc *BlockDesc::InsertOp(size_t index) {
 }

 void BlockDesc::RemoveOp(size_t s, size_t e) {
-  if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
+  if (ops_.begin() + s >= ops_.end() || ops_.begin() + e > ops_.end()) {
     return;
   }
   need_update_ = true;
......
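For clarity, here is a minimal, self-contained sketch of why the new bound check is needed, using a plain std::vector as a stand-in for BlockDesc's op list (the names are illustrative only): with the old `==` comparisons, an index past the end slipped through and the erase would run on an invalid range.

```cpp
#include <memory>
#include <vector>

struct Op {};  // stand-in for the framework's OpDesc

// Remove ops in the half-open range [s, e); silently ignore invalid ranges.
void RemoveRange(std::vector<std::unique_ptr<Op>> *ops, size_t s, size_t e) {
  // Equivalent to the >= / > iterator checks introduced above: reject any
  // start or end index that lies beyond the container, not just one that
  // lands exactly on end().
  if (s >= ops->size() || e > ops->size()) {
    return;
  }
  ops->erase(ops->begin() + s, ops->begin() + e);
}
```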
@@ -15,12 +15,14 @@ if(WITH_GPU)
         dynload_cuda)
     set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
+    nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
 else()
     set(multi_devices_graph_builder_deps)
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
+    cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()

-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
......
...@@ -19,14 +19,12 @@ ...@@ -19,14 +19,12 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
void BroadcastOpHandle::RunImpl() { void BroadcastOpHandle::RunImpl() {
// the input and output may have dummy var. if (places_.size() == 1) return;
VarHandle *in_var_handle;
// The input and output may have dummy vars.
VarHandle *in_var_handle;
{ {
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1, PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
...@@ -55,27 +53,97 @@ void BroadcastOpHandle::RunImpl() { ...@@ -55,27 +53,97 @@ void BroadcastOpHandle::RunImpl() {
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
for (auto *out : out_var_handles) { // NOTE: The tensors' Place of input and output must be all on GPU or all on
if (*out == *in_var_handle) { // CPU.
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
continue; continue;
} }
auto t_out_p = out_var_handle->place_;
auto &out_p = out->place_; auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); ->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(), if (platform::is_gpu_place(in_tensor.place())) {
"Places must be all on CPU or all on CUDA."); PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var); VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
in_tensor.type()); in_tensor.type());
}
if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
continue;
}
auto &out_p = out_var_handle->place_;
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
RunAndRecordEvent(out_p, [in_tensor, out_var] {
paddle::framework::TensorCopy(
in_tensor, platform::CPUPlace(),
&VariableVisitor::GetMutableTensor(out_var));
});
}
} else {
#ifdef PADDLE_WITH_CUDA
VarHandle *out_handle = nullptr;
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
std::vector<std::function<void()>> broadcast_calls;
for (auto out_var_handle : out_var_handles) {
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
int dst_id =
boost::get<platform::CUDAPlace>(out_var_handle->place_).device;
auto &nccl_ctx = nccl_ctxs_->at(dst_id);
void *send_recv_buffer = nullptr;
if (root_id == dst_id) {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
out_handle = out_var_handle;
} else {
send_recv_buffer =
VariableVisitor::GetMutableTensor(out_var).mutable_data(
out_var_handle->place_);
}
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
broadcast_calls.emplace_back(
[send_recv_buffer, numel, type, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclBcast(
send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
root_id, nccl_ctx.comm_, nccl_ctx.stream()));
});
}
auto dev_ctx = dev_ctxes_.at(out_p); this->RunAndRecordEvent([&] {
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { {
paddle::framework::TensorCopy( platform::NCCLGroupGuard guard;
in_tensor, out_p, *(dev_ctx), for (auto &call : broadcast_calls) {
&VariableVisitor::GetMutableTensor(out_var)); call();
}
}
if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
}); });
#else
PADDLE_THROW("CUDA is not enabled.");
#endif
} }
} }
......
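For reference, a minimal sketch of the grouped NCCL broadcast pattern that RunImpl assembles above: one ncclBcast call per device, all issued between ncclGroupStart/ncclGroupEnd so they launch as a single collective. The communicators, streams, and buffers below are assumed to be set up elsewhere (e.g. by something like NCCLContextMap); none of these names come from the framework.

```cpp
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Broadcast `count` floats from the root rank's buffer to every device.
// comms/streams/buffers are one-per-GPU and assumed to be already initialized.
void BroadcastToAllDevices(const std::vector<ncclComm_t> &comms,
                           const std::vector<cudaStream_t> &streams,
                           const std::vector<float *> &buffers, size_t count,
                           int root) {
  ncclGroupStart();  // fuse the per-device calls into one collective launch
  for (size_t i = 0; i < comms.size(); ++i) {
    // On the root rank buffers[i] is read; on every other rank it is written.
    ncclBcast(buffers[i], count, ncclFloat, root, comms[i], streams[i]);
  }
  ncclGroupEnd();
}
```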
...@@ -24,14 +24,32 @@ ...@@ -24,14 +24,32 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct BroadcastOpHandle : public OpHandleBase { struct BroadcastOpHandle : public OpHandleBase {
public: public:
#ifdef PADDLE_WITH_CUDA
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
}
}
}
#else
BroadcastOpHandle(const std::vector<Scope *> &local_scopes, BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}
#endif
std::string Name() const override; std::string Name() const override;
...@@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase {
private: private:
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_; const std::vector<platform::Place> &places_;
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
#endif
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -35,15 +35,25 @@ struct TestBroadcastOpHandle { ...@@ -35,15 +35,25 @@ struct TestBroadcastOpHandle {
std::unique_ptr<OpHandleBase> op_handle_; std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<std::unique_ptr<VarHandleBase>> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
bool use_gpu_;
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif
void WaitAll() { void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) { for (size_t j = 0; j < ctxs_.size(); ++j) {
ctxs_[j]->Wait(); ctxs_[j]->Wait();
} }
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_) {
nccl_ctxs_->WaitAll();
}
#endif
} }
void InitCtxOnGpu(bool use_gpu) { void InitCtxOnGpu(bool use_gpu) {
if (use_gpu) { use_gpu_ = use_gpu;
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
int count = p::GetCUDADeviceCount(); int count = p::GetCUDADeviceCount();
if (count <= 1) { if (count <= 1) {
...@@ -57,6 +67,7 @@ struct TestBroadcastOpHandle { ...@@ -57,6 +67,7 @@ struct TestBroadcastOpHandle {
gpu_list_.push_back(p); gpu_list_.push_back(p);
ctxs_.emplace_back(new p::CUDADeviceContext(p)); ctxs_.emplace_back(new p::CUDADeviceContext(p));
} }
nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_));
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not support.");
#endif #endif
...@@ -67,6 +78,9 @@ struct TestBroadcastOpHandle { ...@@ -67,6 +78,9 @@ struct TestBroadcastOpHandle {
gpu_list_.push_back(p); gpu_list_.push_back(p);
ctxs_.emplace_back(new p::CPUDeviceContext(p)); ctxs_.emplace_back(new p::CPUDeviceContext(p));
} }
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_.reset(nullptr);
#endif
} }
} }
...@@ -82,7 +96,21 @@ struct TestBroadcastOpHandle { ...@@ -82,7 +96,21 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(
new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
#else
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
#endif
}
auto* in_var_handle = auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
...@@ -97,7 +125,9 @@ struct TestBroadcastOpHandle { ...@@ -97,7 +125,9 @@ struct TestBroadcastOpHandle {
op_handle_->AddInput(dummy_var_handle); op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
......
...@@ -25,6 +25,7 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes, ...@@ -25,6 +25,7 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {} : local_scopes_(local_scopes), places_(places) {}
void GatherOpHandle::RunImpl() { void GatherOpHandle::RunImpl() {
if (places_.size() == 1) return;
// the input and output may have dummy var. // the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
...@@ -35,7 +36,6 @@ void GatherOpHandle::RunImpl() { ...@@ -35,7 +36,6 @@ void GatherOpHandle::RunImpl() {
VarHandle *out_var_handle; VarHandle *out_var_handle;
{ {
auto out_var_handles = DynamicCast<VarHandle>(outputs_); auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one."); "The number of output should be one.");
out_var_handle = out_var_handles.front(); out_var_handle = out_var_handles.front();
...@@ -50,68 +50,62 @@ void GatherOpHandle::RunImpl() { ...@@ -50,68 +50,62 @@ void GatherOpHandle::RunImpl() {
auto pre_in_var = auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_);
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(), PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows."); "Currently, gather_op only can gather SelectedRows.");
auto pre_place = in_0_handle->place_;
PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
"The place of input and output should be the same.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated(in_var_handles);
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors; std::vector<Tensor> in_tensors;
std::vector<platform::Place> in_places;
auto &pre_in = pre_in_var->Get<framework::SelectedRows>(); // Gather the inputs
// gather the inputs
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
in_places.push_back(in_p);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
auto *in_var = auto *in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>(); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);
PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(), auto &in_sr_value = in_var->Get<framework::SelectedRows>();
"The type of input is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
"The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent.");
auto &in_sr_rows = in_sr.rows(); auto &in_sr_rows = in_sr_value.rows();
out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
in_tensors.emplace_back(in_sr_value.value());
in_tensors.emplace_back(in_sr.value());
} }
// write the output // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
auto &out_place = out_var_handle->place_; platform::Place t_out_p = out_var_handle->place_;
auto out_scope_idx = out_var_handle->scope_idx_; if (platform::is_gpu_place(pre_in_value.place())) {
auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_); PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
}
auto out = out_var->GetMutable<framework::SelectedRows>(); auto out_var =
out->set_height(pre_in.height()); var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
out->set_rows(out_rows); PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_value = out_var->GetMutable<framework::SelectedRows>();
out_value->set_height(pre_in_value.height());
out_value->set_rows(out_rows);
size_t rows = out_rows.size(); size_t rows = out_rows.size();
DDim out_dim = pre_in.GetCompleteDims(); DDim out_dim = pre_in_value.GetCompleteDims();
out_dim[0] = static_cast<int64_t>(rows); out_dim[0] = static_cast<int64_t>(rows);
out->mutable_value()->Resize(out_dim); out_value->mutable_value()->Resize(out_dim).mutable_data(
out->mutable_value()->mutable_data(out_place, pre_in.value().type()); t_out_p, pre_in_value.value().type());
Tensor *out_tensor = out->mutable_value(); Tensor *out_tensor = out_value->mutable_value();
// copy // copy
auto dev_ctx = dev_ctxes_[out_place]; auto dev_ctx = dev_ctxes_[out_var_handle->place_];
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] { RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
t_out_p] {
int s = 0, e = 0; int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) { for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0]; e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e); auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx), paddle::framework::TensorCopy(in_tensors[j], t_out_p, *dev_ctx, &sub_out);
&sub_out);
s = e; s = e;
} }
}); });
......
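The copy loop above concatenates every input's value tensor into consecutive slices of the output. A hedged, framework-free sketch of that slice-and-copy concatenation, assuming contiguous row-major float blocks (all names here are illustrative):

```cpp
#include <cstring>
#include <vector>

// Concatenate several row-major blocks (each rows[i] x width floats) into one
// contiguous output buffer, mirroring the Slice/TensorCopy loop above.
void ConcatRows(const std::vector<const float *> &inputs,
                const std::vector<int64_t> &rows, int64_t width, float *out) {
  int64_t offset = 0;  // running row offset into the output
  for (size_t i = 0; i < inputs.size(); ++i) {
    std::memcpy(out + offset * width, inputs[i],
                sizeof(float) * rows[i] * width);
    offset += rows[i];
  }
}
```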
...@@ -11,9 +11,11 @@ ...@@ -11,9 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -34,8 +36,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -34,8 +36,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs) platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes), local_scopes_(local_scopes),
...@@ -105,6 +107,11 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, ...@@ -105,6 +107,11 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types;
for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType();
}
auto graph = new SSAGraph(); auto graph = new SSAGraph();
SSAGraph &result = *graph; SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
...@@ -133,12 +140,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -133,12 +140,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
is_forwarding = false; is_forwarding = false;
} else { } else {
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
InsertNCCLAllReduceOp(&result, og); if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, og, 0);
CreateBroadcastOp(&result, og, 0);
} else {
InsertNCCLAllReduceOp(&result, og);
}
} }
} }
} }
...@@ -165,6 +177,50 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -165,6 +177,50 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const {
PADDLE_ENFORCE(var_types.count(og) != 0);
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
}
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
const std::string &p_name,
size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
#endif
result->ops_.emplace_back(op_handle);
auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i];
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
#ifndef PADDLE_WITH_CUDA
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
const OpDesc &op,
int dev_id) const {
result->ops_.emplace_back(
new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
CreateOpHandleIOs(result, op, dev_id);
}
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const { const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
...@@ -174,7 +230,6 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( ...@@ -174,7 +230,6 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
} }
return nullptr; return nullptr;
} }
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const { SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -247,6 +302,36 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, ...@@ -247,6 +302,36 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
const std::string &og,
int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_[i][og];
#ifndef PADDLE_WITH_CUDA
auto &p = places_[i];
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
}
auto &vars = result->vars_[dst_dev_id][og];
auto var =
new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var);
op_handle->AddOutput(var);
return var;
}
void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
auto &p = places_[0]; auto &p = places_[0];
...@@ -263,6 +348,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { ...@@ -263,6 +348,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
return op.OutputArgumentNames().size() == 1 && return op.OutputArgumentNames().size() == 1 &&
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
} }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
...@@ -27,6 +27,7 @@ class NCCLContextMap; ...@@ -27,6 +27,7 @@ class NCCLContextMap;
namespace framework { namespace framework {
class Scope; class Scope;
namespace details { namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder { class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public: public:
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -34,8 +35,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -34,8 +35,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
bool skip_scale_loss, platform::NCCLContextMap *nccl_ctxs,
platform::NCCLContextMap *nccl_ctxs); bool use_default_grad_scale);
#else #else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places, MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
...@@ -74,6 +75,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -74,6 +75,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const; void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
...@@ -81,11 +86,18 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -81,11 +86,18 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
/** /**
* Get send op in the global block of program. * Get send op in the global block of program.
* nullptr if not found. * nullptr if not found.
*/ */
OpDesc *GetSendOpDesc(const ProgramDesc &program) const; OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -22,6 +22,7 @@ namespace framework { ...@@ -22,6 +22,7 @@ namespace framework {
namespace details { namespace details {
void ReduceOpHandle::RunImpl() { void ReduceOpHandle::RunImpl() {
if (places_.size() == 1) return;
// the input and output may have dummy var. // the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
...@@ -51,44 +52,48 @@ void ReduceOpHandle::RunImpl() { ...@@ -51,44 +52,48 @@ void ReduceOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated(in_var_handles);
auto pre_place = in_0_handle->place_;
std::vector<platform::Place> in_places;
auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var);
for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
in_places.emplace_back(in_p);
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx
for (auto *in_handle : in_var_handles) {
in_places.emplace_back(in_handle->place_);
auto in_var = auto in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var);
auto in_tensor = VariableVisitor::GetMutableTensor(in_var);
PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(),
"The type of input is not consistent.");
} }
auto out_var = auto out_var =
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
platform::Place t_out_p;
if (platform::is_gpu_place(in_p)) {
PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_),
"Places of input and output must be all on GPU.");
t_out_p = out_var_handle->place_;
} else {
t_out_p = platform::CPUPlace();
}
if (pre_in_var->IsType<framework::SelectedRows>()) { if (pre_in_var->IsType<framework::SelectedRows>()) {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var_handle->place_,
out_var->GetMutable<framework::SelectedRows>()); out_var->GetMutable<framework::SelectedRows>());
} else { } else {
std::vector<const LoDTensor *> lod_tensors = std::vector<const LoDTensor *> lod_tensors =
GetInputValues<LoDTensor>(in_var_handles, var_scopes); GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(pre_place)) { if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
ReduceLoDTensor func(lod_tensors, ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>()); out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(ToDataType(lod_tensors[0]->type()), func);
} else if (paddle::platform::is_gpu_place(pre_place)) { } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto pre_in = pre_in_var->Get<framework::LoDTensor>(); auto pre_in = pre_in_var->Get<framework::LoDTensor>();
VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var); VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var);
...@@ -96,7 +101,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -96,7 +101,7 @@ void ReduceOpHandle::RunImpl() {
out_var_handle->place_, pre_in.type()); out_var_handle->place_, pre_in.type());
auto out_p = out_var_handle->place_; auto out_p = out_var_handle->place_;
int root = boost::get<platform::CUDAPlace>(out_p).device; int root_id = boost::get<platform::CUDAPlace>(out_p).device;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < var_scopes.size(); ++i) { for (size_t i = 0; i < var_scopes.size(); ++i) {
auto &p = in_places[i]; auto &p = in_places[i];
...@@ -104,23 +109,23 @@ void ReduceOpHandle::RunImpl() { ...@@ -104,23 +109,23 @@ void ReduceOpHandle::RunImpl() {
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
void *buffer = const_cast<void *>(lod_tensor.data<void>()); void *buffer = const_cast<void *>(lod_tensor.data<void>());
void *recvbuffer = nullptr; void *recvbuffer = nullptr;
if (root == dev_id) { if (root_id == dev_id) {
recvbuffer = recvbuffer =
out_var->GetMutable<framework::LoDTensor>()->mutable_data( out_var->GetMutable<framework::LoDTensor>()->mutable_data(
out_var_handle->place_); out_var_handle->place_);
} }
int type = platform::ToNCCLDataType(lod_tensor.type()); int type = platform::ToNCCLDataType(lod_tensor.type());
all_reduce_calls.emplace_back([=] { size_t numel = static_cast<size_t>(lod_tensor.numel());
PADDLE_ENFORCE(platform::dynload::ncclReduce( all_reduce_calls.emplace_back(
buffer, recvbuffer, static_cast<size_t>(lod_tensor.numel()), [buffer, recvbuffer, type, numel, root_id, &nccl_ctx] {
static_cast<ncclDataType_t>(type), ncclSum, root, comm, stream)); PADDLE_ENFORCE(platform::dynload::ncclReduce(
}); buffer, recvbuffer, numel, static_cast<ncclDataType_t>(type),
ncclSum, root_id, nccl_ctx.comm_, nccl_ctx.stream()));
});
} }
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
...@@ -130,7 +135,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -130,7 +135,7 @@ void ReduceOpHandle::RunImpl() {
} }
}); });
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
} else { } else {
PADDLE_THROW("Place should be CPUPlace or CUDAPlace."); PADDLE_THROW("Place should be CPUPlace or CUDAPlace.");
......
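Analogously, a minimal sketch of the grouped ncclReduce pattern the reduce handle builds above, summing every device's buffer into the root rank's output (again, communicators, streams, and buffers are assumed to be initialized elsewhere and the names are illustrative):

```cpp
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Sum-reduce `count` floats from all devices into `recv` on the root rank.
void ReduceToRoot(const std::vector<ncclComm_t> &comms,
                  const std::vector<cudaStream_t> &streams,
                  const std::vector<float *> &sendbufs, float *recv,
                  size_t count, int root) {
  ncclGroupStart();
  for (size_t i = 0; i < comms.size(); ++i) {
    // Only the root rank's receive buffer is written; others may pass nullptr.
    float *recvbuf = (static_cast<int>(i) == root) ? recv : nullptr;
    ncclReduce(sendbufs[i], recvbuf, count, ncclFloat, ncclSum, root, comms[i],
               streams[i]);
  }
  ncclGroupEnd();
}
```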
@@ -55,7 +55,7 @@ struct ReduceOpHandle : public OpHandleBase {
   std::string Name() const override;

-  bool IsMultiDeviceTransfer() override { return false; };
+  bool IsMultiDeviceTransfer() override { return true; };

  protected:
   void RunImpl() override;
......
@@ -62,7 +62,7 @@ struct VarHandle : public VarHandleBase {
   std::string name_;
   platform::Place place_;

-  bool operator==(const VarHandle& o) const {
+  bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
......
...@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { ...@@ -88,6 +88,52 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
VisitVariable(src, &visitor); VisitVariable(src, &visitor);
} }
struct EnforceShapeAndDTypeEQVisitor {
const Variable* trg_;
void operator()(const LoDTensor& src) {
auto& tensor = trg_->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
src.place().which(), tensor.place().which(),
"The Places of the two Variable must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(src.type(), tensor.type(),
"The dtype of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.dims(), tensor.dims(),
"The dims of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(),
"The lod of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.layout(), tensor.layout(),
"The layout of the two Variable's tensor is not equal.");
}
void operator()(const SelectedRows& src) {
auto& selected_rows = trg_->Get<SelectedRows>();
PADDLE_ENFORCE_EQ(
src.place().which(), selected_rows.place().which(),
"The Places of the two Variable must be all on CPU or all on GPU.");
PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(),
"The dtype of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.value().layout(), selected_rows.value().layout(),
"The layout of the two Variable's tensor is not equal.");
PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(),
"The height of the two Variable is not equal.");
PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(),
"The dims of the two Variable is not equal.");
}
template <typename T>
void operator()(const T&) {
PADDLE_ENFORCE("EnforceShapeAndDTypeEQ is not supported by type %s",
typeid(T).name());
}
};
void VariableVisitor::EnforceShapeAndDTypeEQ(const Variable& var1,
const Variable& var2) {
EnforceShapeAndDTypeEQVisitor visitor{&var1};
VisitVariable(var2, &visitor);
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,6 +26,9 @@ class VariableVisitor { ...@@ -26,6 +26,9 @@ class VariableVisitor {
static Tensor &GetMutableTensor(Variable *var); static Tensor &GetMutableTensor(Variable *var);
static void ShareDimsAndLoD(const Variable &src, Variable *trg); static void ShareDimsAndLoD(const Variable &src, Variable *trg);
static void EnforceShapeAndDTypeEQ(const Variable &var1,
const Variable &var2);
}; };
} // namespace details } // namespace details
......
@@ -348,8 +348,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       }
     }
   }
+  platform::DeviceContextPool::Instance().Get(place_)->Wait();
   if (create_vars && create_local_scope) {
     scope->DeleteScope(local_scope);
+  } else {
+    // Delete the local scopes created in operators.
+    scope->DropKids();
   }
   if (FLAGS_benchmark) {
     VLOG(2) << "-------------------------------------------------------";
......
@@ -93,7 +93,7 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(
       member_->places_, loss_var_name, params, member_->local_scopes_,
-      use_default_grad_scale, member_->nccl_ctxs_.get());
+      member_->nccl_ctxs_.get(), use_default_grad_scale);
 #else
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
......
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
 add_subdirectory(convert)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
class DefaultInputConverter : public EngineInputConverter {
public:
DefaultInputConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
PADDLE_ENFORCE_LE(in.memory_size(), max_size);
const auto& place = in.place();
if (is_cpu_place(place)) {
PADDLE_ENFORCE(stream_ != nullptr);
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyDeviceToDevice, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
cudaStreamSynchronize(*stream_);
}
};
REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using framework::LoDTensor;
/*
* Convert Input from Fluid to an Engine.
* TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
* most cases just need to copy the data.
*/
class EngineInputConverter {
public:
EngineInputConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
virtual ~EngineInputConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
struct trt_input_##in_op_type__##_converter { \
trt_input_##in_op_type__##_converter() { \
::paddle::inference::Registry<EngineInputConverter>::Register< \
Converter__>(#in_op_type__); \
} \
}; \
trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include <gtest/gtest.h>
namespace paddle {
namespace inference {
namespace tensorrt {
class EngineInputConverterTester : public ::testing::Test {
public:
void SetUp() override { tensor.Resize({10, 10}); }
framework::LoDTensor tensor;
};
TEST_F(EngineInputConverterTester, DefaultCPU) {
void* buffer;
tensor.mutable_data<float>(platform::CPUPlace());
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
&stream);
}
TEST_F(EngineInputConverterTester, DefaultGPU) {
void* buffer;
tensor.mutable_data<float>(platform::CUDAPlace());
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
&stream);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -24,6 +24,11 @@ function(inference_test TARGET_NAME)
   endforeach()
 endfunction(inference_test)

+####################
+# The inference tests here depend on fluid/tests/book. If users want to run an
+# individual test with ctest, they need to run the tests in fluid/tests/book
+# first to generate the saved models.
+####################
 # This unittest is buggy!
 #inference_test(fit_a_line)
 inference_test(image_classification ARGS vgg resnet)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
// NOTE not thread-safe.
template <typename T>
struct Singleton {
static T& Global() {
static T* x = new T;
return *x;
}
Singleton() = delete;
Singleton& operator=(const Singleton&) = delete;
};
/*
* A registry for any type.
* NOTE not thread-safe.
*/
template <typename ItemParent>
struct Registry {
static Registry& Global() {
static auto* x = new Registry<ItemParent>;
return *x;
}
template <typename ItemChild>
static void Register(const std::string& name) {
PADDLE_ENFORCE_EQ(items_.count(name), 0);
items_[name] = new ItemChild;
}
static ItemParent* Lookup(const std::string& name) {
auto it = items_.find(name);
if (it == items_.end()) return nullptr;
return it->second;
}
~Registry() {
for (auto& item : items_) {
delete item.second;
}
}
private:
Registry() = default;
static std::unordered_map<std::string, ItemParent*> items_;
};
template <typename ItemParent>
std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;
} // namespace inference
} // namespace paddle
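As a usage sketch for the Registry above (Base and Impl are made-up names, not framework types): a concrete type is registered once under a string key, typically from a file-scope registrar object such as the one REGISTER_TENSORRT_INPUT_CONVERTER generates, and is later looked up by that key.

```cpp
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {
namespace inference {

struct Base {
  virtual ~Base() = default;
};
struct Impl : public Base {};

// Register Impl under "impl" exactly once, then fetch it by name anywhere.
// Lookup returns nullptr for keys that were never registered.
inline Base *GetImpl() {
  static bool registered = (Registry<Base>::Register<Impl>("impl"), true);
  (void)registered;
  return Registry<Base>::Lookup("impl");
}

}  // namespace inference
}  // namespace paddle
```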
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel { class AdadeltaOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdagradOp : public framework::OperatorWithKernel { class AdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamOp : public framework::OperatorWithKernel { class AdamOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamOpMaker : public framework::OpProtoAndCheckerMaker { class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamaxOp : public framework::OperatorWithKernel { class AdamaxOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims); ctx->SetOutputDim("InfNormOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -366,7 +366,8 @@ REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
-                   paddle::operators::CUDNNConvOpKernel<double>);
+                   paddle::operators::CUDNNConvOpKernel<double>,
+                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
                    paddle::operators::CUDNNConvGradOpKernel<double>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class DecayedAdagradOp : public framework::OperatorWithKernel { class DecayedAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ...@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -69,6 +69,10 @@ message VariableMessage { ...@@ -69,6 +69,10 @@ message VariableMessage {
bytes rows = 9; bytes rows = 9;
// Look up table block execution output variable name. // Look up table block execution output variable name.
string out_varname = 10; string out_varname = 10;
// If true, the PS server starts profiling; it stops profiling and
// generates a profile to /tmp/profile_ps_* when the flag switches
// from true to false.
bool profile = 11;
} }
message VoidMessage {} message VoidMessage {}
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h" #include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -45,6 +46,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -45,6 +46,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* payload = nullptr; void* payload = nullptr;
size_t payload_size; size_t payload_size;
ProtoEncodeHelper e(static_cast<char*>(buf), 1024); ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
// Note: normally the profiler is enabled on only one trainer, so only
// that trainer returns true for ShouldSendProfileState(). It tells the
// PS servers the trainer's profiling state so that the PS servers can
// follow the trainer.
if (platform::ShouldSendProfileState()) {
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
}
e.WriteString(VarMsg::kVarnameFieldNumber, name); e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0); e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...@@ -427,7 +428,26 @@ int VariableResponse::Parse(Source* source) { ...@@ -427,7 +428,26 @@ int VariableResponse::Parse(Source* source) {
meta_.set_out_varname(temp); meta_.set_out_varname(temp);
break; break;
} }
case sendrecv::VariableMessage::kProfileFieldNumber: {
bool profiling;
if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling && !platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (!profiling && platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow customizing the output file dir?
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: { default: {
// Unknown tag, return unknown error. // Unknown tag, return unknown error.
return -1; return -1;
......
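Taken together, the serialization change above and this parsing change implement an edge-triggered toggle: the trainer attaches its profiling state to outgoing variables, and the PS enables or disables its own profiler when the received flag differs from its current state. A minimal self-contained model of the same idea (plain C++ with hypothetical names, not the Paddle API):

// Standalone model of the trainer -> PS profiling toggle carried on RPC
// messages. All names here are illustrative; the real code uses the
// VariableMessage.profile field and platform::EnableProfiler/DisableProfiler.
#include <iostream>
#include <string>

struct Message {
  std::string varname;
  bool has_profile;  // the profile field is optional on the wire
  bool profile;
};

class PsProfiler {
 public:
  void OnMessage(const Message& msg) {
    if (!msg.has_profile) return;        // plain messages carry no state
    if (msg.profile && !enabled_) {
      enabled_ = true;                   // trainer started profiling
      std::cout << "PS: profiler enabled\n";
    } else if (!msg.profile && enabled_) {
      enabled_ = false;                  // trainer stopped: dump a profile
      std::cout << "PS: profile written to /tmp/profile_ps_<listener_id>\n";
    }
  }

 private:
  bool enabled_ = false;
};

int main() {
  PsProfiler ps;
  ps.OnMessage({"w@GRAD", true, true});   // trainer turned profiling on
  ps.OnMessage({"w@GRAD", true, true});   // repeated flag: no state change
  ps.OnMessage({"w@GRAD", true, false});  // trainer turned it off -> dump
}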
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class FTRLOp : public framework::OperatorWithKernel { class FTRLOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("SquaredAccumOut", param_dim);
ctx->SetOutputDim("LinearAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
// Mark this process as a PS: it decides whether to profile by listening to the trainer.
platform::SetProfileListener();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
...@@ -328,9 +331,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -328,9 +331,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->WaitServerReady(); rpc_service_->WaitServerReady();
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
std::string file_path = std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
string::Sprintf("/tmp/paddle.%d.selected_port", static_cast<int>(::getpid()));
static_cast<int>(::getpid()));
SavePort(file_path); SavePort(file_path);
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
......
...@@ -46,8 +46,7 @@ class LoDResetKernel : public framework::OpKernel<T> { ...@@ -46,8 +46,7 @@ class LoDResetKernel : public framework::OpKernel<T> {
auto* lod = lod_t->data<int>(); auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu; framework::Tensor lod_cpu;
framework::TensorCopy(*lod_t, platform::CPUPlace(), framework::TensorCopySync(*lod_t, platform::CPUPlace(), &lod_cpu);
ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>(); lod = lod_cpu.data<int>();
} }
level0 = std::vector<int>(lod, lod + lod_t->numel()); level0 = std::vector<int>(lod, lod + lod_t->numel());
......
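TensorCopy enqueues the copy on the device context's stream and returns immediately, so reading the destination on the host right afterwards (as lod_cpu.data<int>() does here) races with the copy unless the stream is synchronized first; TensorCopySync blocks until the copy has finished. A rough standalone model of the hazard and the fix, with std::async standing in for the asynchronous stream (none of this is Paddle API):

// Rough model of "async device->host copy followed by an immediate host read".
// std::async plays the role of the CUDA stream; the fix is simply to wait for
// the copy to complete before touching the destination buffer.
#include <algorithm>
#include <future>
#include <iostream>
#include <vector>

std::future<void> CopyAsync(const std::vector<int>& src, std::vector<int>* dst) {
  // The launched work may start at any time; the caller only knows it is done
  // once the returned future has been waited on.
  return std::async(std::launch::async,
                    [&src, dst] { std::copy(src.begin(), src.end(), dst->begin()); });
}

void CopySync(const std::vector<int>& src, std::vector<int>* dst) {
  CopyAsync(src, dst).wait();  // what TensorCopySync adds over TensorCopy
}

int main() {
  std::vector<int> device_lod = {0, 2, 5, 7};
  std::vector<int> host_lod(device_lod.size(), -1);

  CopySync(device_lod, &host_lod);  // safe: the copy finished before the read
  std::cout << "first offset: " << host_lod[1] << "\n";

  // With CopyAsync alone, reading host_lod here without wait() could still
  // observe the stale -1 values, which is the bug this commit avoids.
}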
...@@ -62,7 +62,7 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -62,7 +62,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto w_t = w_var->GetMutable<framework::SelectedRows>(); auto w_t = w_var->GetMutable<framework::SelectedRows>();
std::vector<int64_t> keys; std::vector<int64_t> keys;
keys.resize(ids_t.numel()); keys.resize(ids_t.numel());
for (size_t i = 0; i < ids_t.numel(); ++i) { for (int64_t i = 0; i < ids_t.numel(); ++i) {
keys[i] = ids_t.data<int64_t>()[i]; keys[i] = ids_t.data<int64_t>()[i];
} }
......
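numel() returns int64_t, so the old size_t loop index forced a signed/unsigned comparison on every iteration (and a sign-compare warning); switching the index to int64_t keeps the comparison homogeneous. A tiny illustration with a hypothetical container:

// Tiny illustration of why the index type is changed to match numel()'s
// signed return type. "FakeTensor" is purely illustrative, not Paddle code.
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeTensor {
  std::vector<int64_t> data;
  int64_t numel() const { return static_cast<int64_t>(data.size()); }
};

int main() {
  FakeTensor ids{{7, 3, 9}};
  std::vector<int64_t> keys(ids.numel());

  // for (size_t i = 0; i < ids.numel(); ++i)  // mixes unsigned and signed:
  //                                           // -Wsign-compare fires here.
  for (int64_t i = 0; i < ids.numel(); ++i) {  // index type matches numel()
    keys[i] = ids.data[i];
  }
  std::cout << "copied " << keys.size() << " keys\n";
}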
...@@ -41,7 +41,7 @@ math_library(depthwise_conv) ...@@ -41,7 +41,7 @@ math_library(depthwise_conv)
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
math_library(im2col) math_library(im2col)
math_library(lstm_compute DEPS activation_functions) math_library(lstm_compute DEPS activation_functions)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
math_library(math_function DEPS blas) math_library(math_function DEPS blas)
math_library(maxouting) math_library(maxouting)
math_library(pooling) math_library(pooling)
......
...@@ -69,8 +69,8 @@ void testConcat() { ...@@ -69,8 +69,8 @@ void testConcat() {
} }
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(input_a_cpu, Place(), *context, &input_a); paddle::framework::TensorCopySync(input_a_cpu, Place(), &input_a);
paddle::framework::TensorCopy(input_b_cpu, Place(), *context, &input_b); paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
} }
std::vector<paddle::framework::Tensor> input; std::vector<paddle::framework::Tensor> input;
...@@ -86,8 +86,8 @@ void testConcat() { ...@@ -86,8 +86,8 @@ void testConcat() {
int* out_ptr; int* out_ptr;
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(out, paddle::platform::CPUPlace(), *context, paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu); &out_cpu);
out_ptr = out_cpu.data<int>(); out_ptr = out_cpu.data<int>();
} else { } else {
out_ptr = out.data<int>(); out_ptr = out.data<int>();
...@@ -142,8 +142,8 @@ void testConcat() { ...@@ -142,8 +142,8 @@ void testConcat() {
} }
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(input_a_cpu, Place(), *context, &input_a); paddle::framework::TensorCopySync(input_a_cpu, Place(), &input_a);
paddle::framework::TensorCopy(input_b_cpu, Place(), *context, &input_b); paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
} }
input.clear(); input.clear();
...@@ -157,8 +157,8 @@ void testConcat() { ...@@ -157,8 +157,8 @@ void testConcat() {
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(out, paddle::platform::CPUPlace(), *context, paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu); &out_cpu);
out_ptr = out_cpu.data<int>(); out_ptr = out_cpu.data<int>();
} else { } else {
out_ptr = out.data<int>(); out_ptr = out.data<int>();
...@@ -215,8 +215,8 @@ void testConcat() { ...@@ -215,8 +215,8 @@ void testConcat() {
} }
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(input_a_cpu, Place(), *context, &input_a); paddle::framework::TensorCopySync(input_a_cpu, Place(), &input_a);
paddle::framework::TensorCopy(input_b_cpu, Place(), *context, &input_b); paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
} }
input.clear(); input.clear();
...@@ -230,8 +230,8 @@ void testConcat() { ...@@ -230,8 +230,8 @@ void testConcat() {
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(out, paddle::platform::CPUPlace(), *context, paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu); &out_cpu);
out_ptr = out_cpu.data<int>(); out_ptr = out_cpu.data<int>();
} else { } else {
out_ptr = out.data<int>(); out_ptr = out.data<int>();
...@@ -290,8 +290,8 @@ void testConcat() { ...@@ -290,8 +290,8 @@ void testConcat() {
} }
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(input_a_cpu, Place(), *context, &input_a); paddle::framework::TensorCopySync(input_a_cpu, Place(), &input_a);
paddle::framework::TensorCopy(input_b_cpu, Place(), *context, &input_b); paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
} }
input.clear(); input.clear();
...@@ -305,8 +305,8 @@ void testConcat() { ...@@ -305,8 +305,8 @@ void testConcat() {
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
if (paddle::platform::is_gpu_place(Place())) { if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopy(out, paddle::platform::CPUPlace(), *context, paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu); &out_cpu);
out_ptr = out_cpu.data<int>(); out_ptr = out_cpu.data<int>();
} else { } else {
out_ptr = out.data<int>(); out_ptr = out.data<int>();
......
...@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
r_prev_state = value.prev_state_value[i]; r_prev_state = value.prev_state_value[i];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_gate, active_state); active_node, active_gate, active_state);
value_in[i] = r_value_in; value_in[i] = r_value_in;
value_ig[i] = r_value_ig; value_ig[i] = r_value_ig;
...@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
r_prev_state = value.prev_state_value[i]; r_prev_state = value.prev_state_value[i];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_state); active_node, active_gate, active_state);
grad_in[i] = r_grad_in; grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig; grad_ig[i] = r_grad_ig;
...@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_gate, active_state); active_node, active_gate, active_state);
value_in[i] = r_value_in; value_in[i] = r_value_in;
value_ig[i] = r_value_ig; value_ig[i] = r_value_ig;
...@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_state); active_node, active_gate, active_state);
grad_in[i] = r_grad_in; grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig; grad_ig[i] = r_grad_ig;
......
...@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
r_prev_state = value.prev_state_value[frame_idx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_state); active_node, active_gate, active_state);
value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx] = r_value_in;
value.gate_value[frame_idx + frame_size] = r_value_ig; value.gate_value[frame_idx + frame_size] = r_value_ig;
...@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
r_prev_state = value.prev_state_value[frame_idx]; r_prev_state = value.prev_state_value[frame_idx];
} }
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
active_state); active_gate, active_state);
grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx] = r_grad_in;
grad.gate_grad[frame_idx + frame_size] = r_grad_ig; grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
......
...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include <type_traits>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -27,19 +27,19 @@ namespace forward { ...@@ -27,19 +27,19 @@ namespace forward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
T &prev_state, T &state, T &state_atv, T &output, T *prev_state, T *state, T *state_atv, T *output,
T &checkI, T &checkF, T &checkO, T *checkI, T *checkF, T *checkO,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
value_in = activation(value_in, active_node); *value_in = activation(*value_in, active_node);
value_ig = activation(value_ig + prev_state * checkI, active_gate); *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
value_fg = activation(value_fg + prev_state * checkF, active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
state = value_in * value_ig + prev_state * value_fg; *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
value_og = activation(value_og + state * checkO, active_gate); *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
state_atv = activation(state, active_state); *state_atv = activation(*state, active_state);
output = value_og * state_atv; *output = (*value_og) * (*state_atv);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructions. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructions. Disable AVX by default
...@@ -48,27 +48,27 @@ class lstm { ...@@ -48,27 +48,27 @@ class lstm {
// Only float supports AVX optimization // Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig, HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
__m256 &value_fg, __m256 &value_og, __m256 *value_fg, __m256 *value_og,
__m256 &prev_state, __m256 &state, __m256 *prev_state, __m256 *state,
__m256 &state_atv, __m256 &output, __m256 &checkI, __m256 *state_atv, __m256 *output, __m256 *checkI,
__m256 &checkF, __m256 &checkO, __m256 *checkF, __m256 *checkO,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
value_in = activation(value_in, active_node); *value_in = activation(*value_in, active_node);
value_ig = *value_ig = activation(
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)), _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
active_gate); active_gate);
value_fg = *value_fg = activation(
activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)), _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
active_gate); active_gate);
state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig), *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
_mm256_mul_ps(prev_state, value_fg)); _mm256_mul_ps(*prev_state, *value_fg));
value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)), *value_og = activation(
active_gate); _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
state_atv = activation(state, active_state); *state_atv = activation(*state, active_state);
output = _mm256_mul_ps(value_og, state_atv); *output = _mm256_mul_ps(*value_og, *state_atv);
} }
#endif #endif
#endif #endif
...@@ -81,26 +81,29 @@ namespace backward { ...@@ -81,26 +81,29 @@ namespace backward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
T &grad_in, T &grad_ig, T &grad_fg, T &grad_og, T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
T &prev_state, T &prev_state_grad, T &state, T *prev_state, T *prev_state_grad, T *state,
T &state_grad, T &state_atv, T &output_grad, T *state_grad, T *state_atv, T *output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad, T *checkI, T *checkF, T *checkO, T *checkIGrad,
T &checkFGrad, T &checkOGrad, T *checkFGrad, T *checkOGrad,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
grad_og = activation(output_grad * state_atv, value_og, active_gate); *grad_og =
state_grad += activation(output_grad * value_og, state_atv, active_state) + activation((*output_grad) * (*state_atv), *value_og, active_gate);
grad_og * checkO; *state_grad +=
grad_in = activation(state_grad * value_ig, value_in, active_node); activation((*output_grad) * (*value_og), *state_atv, active_state) +
grad_ig = activation(state_grad * value_in, value_ig, active_gate); (*grad_og) * (*checkO);
grad_fg = activation(state_grad * prev_state, value_fg, active_gate); *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
prev_state_grad = *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
grad_ig * checkI + grad_fg * checkF + state_grad * value_fg; *grad_fg =
checkIGrad = grad_ig * prev_state; activation((*state_grad) * (*prev_state), *value_fg, active_gate);
checkFGrad = grad_fg * prev_state; *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
checkOGrad = grad_og * state; (*state_grad) * (*value_fg);
*checkIGrad = (*grad_ig) * (*prev_state);
*checkFGrad = (*grad_fg) * (*prev_state);
*checkOGrad = (*grad_og) * (*state);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructions. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructions. Disable AVX by default
...@@ -109,32 +112,33 @@ class lstm { ...@@ -109,32 +112,33 @@ class lstm {
// Only float supports AVX optimization // Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()( HOSTDEVICE void operator()(
__m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og, __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
__m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og, __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state, __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad, __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node, __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
ActivationType active_gate, ActivationType active_state) { ActivationType active_gate, ActivationType active_state) {
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og, *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
active_gate); active_gate);
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og), *state_grad =
state_atv, active_state), _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
state_grad); *state_atv, active_state),
state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad); *state_grad);
grad_in = *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node); *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
grad_ig = active_node);
activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate); *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg, active_gate);
active_gate); *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI), active_gate);
_mm256_mul_ps(grad_fg, checkF)); *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
prev_state_grad = _mm256_mul_ps(*grad_fg, *checkF));
_mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad); *prev_state_grad =
checkIGrad = _mm256_mul_ps(grad_ig, prev_state); _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
checkFGrad = _mm256_mul_ps(grad_fg, prev_state); *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
checkOGrad = _mm256_mul_ps(grad_og, state); *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
*checkOGrad = _mm256_mul_ps(*grad_og, *state);
} }
#endif #endif
#endif #endif
......
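The LSTM functors' in/out parameters change from non-const references to pointers, which makes mutation visible at every call site (the &r_value_in, &r_state, ... arguments in the hunks above) and matches the style rule the codebase follows. A minimal sketch of the convention with a made-up functor:

// Minimal sketch of the reference -> pointer convention used for the LSTM
// functors: output parameters are passed by pointer so the & at the call
// site signals mutation. "Axpy" is a made-up example, not Paddle code.
#include <iostream>

struct Axpy {
  // Before: void operator()(float a, float x, float& y) { y += a * x; }
  void operator()(float a, float x, float* y) const { *y += a * x; }
};

int main() {
  float y = 1.0f;
  Axpy op;
  op(2.0f, 3.0f, &y);  // the & makes the write to y obvious at the call site
  std::cout << "y = " << y << "\n";  // prints y = 7
}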
...@@ -41,7 +41,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, ...@@ -41,7 +41,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
seq = cpu_seq; seq = cpu_seq;
} else { } else {
TensorCopy(cpu_seq, *place, *context, &seq); TensorCopySync(cpu_seq, *place, &seq);
seq.set_lod(lod); seq.set_lod(lod);
} }
...@@ -64,7 +64,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, ...@@ -64,7 +64,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
cpu_seq_back = seq_back; cpu_seq_back = seq_back;
} else { } else {
TensorCopy(seq_back, paddle::platform::CPUPlace(), *context, &cpu_seq_back); TensorCopySync(seq_back, paddle::platform::CPUPlace(), &cpu_seq_back);
cpu_seq_back.set_lod(lod); cpu_seq_back.set_lod(lod);
} }
......
...@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
// copy index to cpu // copy index to cpu
Tensor index_t_cpu; Tensor index_t_cpu;
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>(); auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
...@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
// copy index to cpu // copy index to cpu
Tensor index_t_cpu; Tensor index_t_cpu;
TensorCopy(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu); TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>(); auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
......
...@@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ...@@ -174,7 +174,8 @@ REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
ops::PoolCUDNNOpKernel<float>, ops::PoolCUDNNOpKernel<float>,
ops::PoolCUDNNOpKernel<double>); ops::PoolCUDNNOpKernel<double>,
ops::PoolCUDNNOpKernel<plat::float16>);
REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace,
ops::PoolCUDNNGradOpKernel<float>, ops::PoolCUDNNGradOpKernel<float>,
ops::PoolCUDNNGradOpKernel<double>); ops::PoolCUDNNGradOpKernel<double>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalAdagradOp : public framework::OperatorWithKernel { class ProximalAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ...@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalGDOp : public framework::OperatorWithKernel { class ProximalGDOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel { ...@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -224,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout, ...@@ -224,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
for (int offset = 16; offset > 0; for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32. offset = offset / 2) { // blockDim.x is 32.
val += platform::__shfl_down_sync(mask, val, offset); val += platform::CudaShuffleDownSync(mask, val, offset);
} }
__syncthreads(); __syncthreads();
...@@ -284,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, ...@@ -284,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
for (int offset = 16; offset > 0; for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32. offset = offset / 2) { // blockDim.x is 32.
val += platform::__shfl_down_sync(mask, val, offset); val += platform::CudaShuffleDownSync(mask, val, offset);
} }
__syncthreads(); __syncthreads();
......
...@@ -66,13 +66,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> { ...@@ -66,13 +66,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
&offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
&length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
...@@ -127,13 +125,11 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> { ...@@ -127,13 +125,11 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
&offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
&length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
......
...@@ -241,7 +241,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid, ...@@ -241,7 +241,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
CREATE_SHFL_MASK(mask, true); CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) { if (maxid[0] / 32 == warp) {
if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
MaxLength)
break; break;
} }
} }
......
...@@ -18,34 +18,33 @@ limitations under the License. */ ...@@ -18,34 +18,33 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000 #if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else #else
#define FULL_WARP_MASK 0xFFFFFFFF #define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \ #define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate)) mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T> template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) { __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
return __shfl_down_sync(mask, val, delta); int delta, int width = 32) {
#if CUDA_VERSION < 9000
return __shfl_down(val, delta, width);
#else
return __shfl_down_sync(mask, val, delta, width);
#endif
} }
template <typename T> template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned mask, T val, int src_line, __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
int width) { int width = 32) {
#if CUDA_VERSION < 9000
return __shfl(val, src_line, width);
#else
return __shfl_sync(mask, val, src_line, width); return __shfl_sync(mask, val, src_line, width);
}
#endif #endif
}
template <typename T> template <typename T>
__device__ T reduceSum(T val, int tid, int len) { __device__ T reduceSum(T val, int tid, int len) {
...@@ -61,7 +60,7 @@ __device__ T reduceSum(T val, int tid, int len) { ...@@ -61,7 +60,7 @@ __device__ T reduceSum(T val, int tid, int len) {
CREATE_SHFL_MASK(mask, tid < len); CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2) for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset); val += platform::CudaShuffleDownSync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0; if (tid < warpSize) shm[tid] = 0;
...@@ -75,7 +74,7 @@ __device__ T reduceSum(T val, int tid, int len) { ...@@ -75,7 +74,7 @@ __device__ T reduceSum(T val, int tid, int len) {
if (tid < warpSize) { if (tid < warpSize) {
val = shm[tid]; val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2) for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset); val += platform::CudaShuffleDownSync(mask, val, offset);
} }
return val; return val;
} }
......
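CudaShuffleDownSync and CudaShuffleSync wrap the pre- and post-CUDA-9 shuffle intrinsics behind one name, and reduceSum uses the offset-halving pattern to reduce a warp in log2(warpSize) steps. A host-side model of that tree reduction (plain C++, no CUDA; the vector stands in for the 32 lanes of a warp):

// Host-side model of the warp tree reduction used in reduceSum: at each step a
// lane adds the value held by the lane `offset` positions above it, and the
// offset halves until lane 0 holds the sum. Plain C++ stand-in, not CUDA.
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const int kWarpSize = 32;
  std::vector<int> warp(kWarpSize);
  std::iota(warp.begin(), warp.end(), 1);  // lanes hold 1..32

  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    // Models val += CudaShuffleDownSync(mask, val, offset): lane i reads the
    // value from lane i + offset; all lanes "execute" before the next step.
    std::vector<int> next = warp;
    for (int lane = 0; lane + offset < kWarpSize; ++lane) {
      next[lane] = warp[lane] + warp[lane + offset];
    }
    warp = next;
  }
  std::cout << "lane 0 holds " << warp[0] << " (expected 528)\n";  // 1+..+32
}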
...@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include <sys/time.h> #include <sys/time.h>
#include <time.h> #include <time.h>
#include <algorithm> #include <algorithm>
#include <iomanip> #include <iomanip>
#include <limits>
#include <map> #include <map>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <random>
#include <string> #include <string>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda.h> #include <cuda.h>
...@@ -33,6 +36,9 @@ namespace platform { ...@@ -33,6 +36,9 @@ namespace platform {
struct EventList; struct EventList;
static int64_t profiler_lister_id = 0;
static bool should_send_profile_state = false;
// The profiler state, the initial value is ProfilerState::kDisabled // The profiler state, the initial value is ProfilerState::kDisabled
static ProfilerState g_state = ProfilerState::kDisabled; static ProfilerState g_state = ProfilerState::kDisabled;
// The thread local event list only can be accessed by the specific thread // The thread local event list only can be accessed by the specific thread
...@@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) { ...@@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE(state != ProfilerState::kDisabled, PADDLE_ENFORCE(state != ProfilerState::kDisabled,
"Can't enbale profling, since the input state is ", "Can't enbale profling, since the input state is ",
"ProfilerState::kDisabled"); "ProfilerState::kDisabled");
PADDLE_ENFORCE(g_state == ProfilerState::kDisabled, if (state == g_state) {
"The profiling state should be disabled when calling ", return;
"EnableProfiler.");
g_state = state;
if (g_state == ProfilerState::kAll) {
GetDeviceTracer()->Enable();
} }
g_state = state;
should_send_profile_state = true;
GetDeviceTracer()->Enable();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (g_state == ProfilerState::kCUDA) { if (g_state == ProfilerState::kCUDA) {
// Generate some dummy events first to reduce the startup overhead. // Generate some dummy events first to reduce the startup overhead.
...@@ -435,8 +440,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events, ...@@ -435,8 +440,7 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
void DisableProfiler(EventSortingKey sorted_key, void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path) { const std::string& profile_path) {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, if (g_state == ProfilerState::kDisabled) return;
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop. // Mark the profiling stop.
Mark("_stop_profiler_", nullptr); Mark("_stop_profiler_", nullptr);
...@@ -444,12 +448,25 @@ void DisableProfiler(EventSortingKey sorted_key, ...@@ -444,12 +448,25 @@ void DisableProfiler(EventSortingKey sorted_key,
ParseEvents(all_events, sorted_key); ParseEvents(all_events, sorted_key);
ResetProfiler(); ResetProfiler();
DeviceTracer* tracer = GetDeviceTracer(); DeviceTracer* tracer = GetDeviceTracer();
if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) { if (tracer->IsEnabled()) {
tracer->Disable(); tracer->Disable();
tracer->GenProfile(profile_path); tracer->GenProfile(profile_path);
} }
g_state = ProfilerState::kDisabled; g_state = ProfilerState::kDisabled;
should_send_profile_state = true;
}
bool IsProfileEnabled() { return g_state != ProfilerState::kDisabled; }
bool ShouldSendProfileState() { return should_send_profile_state; }
void SetProfileListener() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int64_t>::max());
profiler_lister_id = dist6(rng);
} }
int64_t ListenerId() { return profiler_lister_id; }
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
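With the relaxed checks above, EnableProfiler is a no-op when the requested state is already active, and DisableProfiler simply returns if profiling never started, so a caller can bracket a section defensively. A small usage sketch; the profiled function and output path are placeholders:

// Usage sketch for the relaxed Enable/DisableProfiler behaviour above.
// TrainOneBatch and the output path are placeholders, not real Paddle symbols.
#include "paddle/fluid/platform/profiler.h"

void TrainOneBatch() { /* placeholder for the work being profiled */ }

void ProfileFewSteps() {
  namespace plat = paddle::platform;
  // Re-enabling with the same state is now a no-op rather than an enforce
  // failure, so this is safe even if profiling is already on.
  plat::EnableProfiler(plat::ProfilerState::kCPU);
  for (int i = 0; i < 5; ++i) {
    TrainOneBatch();
  }
  // Returns immediately if the profiler was never enabled.
  plat::DisableProfiler(plat::EventSortingKey::kDefault, "/tmp/profile_demo");
}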
...@@ -114,5 +114,13 @@ void ResetProfiler(); ...@@ -114,5 +114,13 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key, void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path); const std::string& profile_path);
// Test if the profiler is currently enabled.
bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS.
bool ShouldSendProfileState();
// Mark the current process as a PS by assigning it a listener id.
void SetProfileListener();
int64_t ListenerId();
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -40,16 +40,14 @@ import backward ...@@ -40,16 +40,14 @@ import backward
import regularizer import regularizer
import average import average
import metrics import metrics
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
from distribute_transpiler import DistributeTranspiler from transpiler import DistributeTranspiler, SimpleDistributeTranspiler, InferenceTranspiler, memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv, from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close, Select) channel_close, Select)
from inference_transpiler import InferenceTranspiler
import clip import clip
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler import profiler
import unique_name import unique_name
import recordio_writer import recordio_writer
...@@ -58,7 +56,7 @@ from parallel_executor import ParallelExecutor ...@@ -58,7 +56,7 @@ from parallel_executor import ParallelExecutor
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
trainer.__all__ + inferencer.__all__ + [ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
...@@ -76,11 +74,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\ ...@@ -76,11 +74,6 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ +\
'WeightNormParamAttr', 'WeightNormParamAttr',
'DataFeeder', 'DataFeeder',
'clip', 'clip',
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'InferenceTranspiler',
'memory_optimize',
'release_memory',
'profiler', 'profiler',
'unique_name', 'unique_name',
'recordio_writer', 'recordio_writer',
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import core import core
import framework
import executor
import io
__all__ = ['Inferencer', ] __all__ = ['Inferencer', ]
...@@ -29,6 +31,15 @@ class Inferencer(object): ...@@ -29,6 +31,15 @@ class Inferencer(object):
# 4. load params from param_path into scope # 4. load params from param_path into scope
self.scope = core.Scope() self.scope = core.Scope()
self.place = place self.place = place
self.startup_program = framework.Program()
# TODO: generate the startup_program with network_func
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def infer(self, inputs): def infer(self, inputs):
# run self.program # run self.program
......
...@@ -251,7 +251,7 @@ class EditDistance(MetricBase): ...@@ -251,7 +251,7 @@ class EditDistance(MetricBase):
self.instance_error += seq_num - seq_right_count self.instance_error += seq_num - seq_right_count
self.total_distance += total_distance self.total_distance += total_distance
def eval(): def eval(self):
if self.seq_num == 0: if self.seq_num == 0:
raise ValueError( raise ValueError(
"There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance." "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance."
...@@ -280,6 +280,7 @@ class DetectionMAP(MetricBase): ...@@ -280,6 +280,7 @@ class DetectionMAP(MetricBase):
super(DetectionMAP, self).__init__(name) super(DetectionMAP, self).__init__(name)
# the current map value # the current map value
self.value = .0 self.value = .0
self.weight = .0
def update(self, value, weight): def update(self, value, weight):
if not _is_number_or_matrix_(value): if not _is_number_or_matrix_(value):
...@@ -340,8 +341,8 @@ class Auc(MetricBase): ...@@ -340,8 +341,8 @@ class Auc(MetricBase):
raise ValueError("The 'predictions' must be a numpy ndarray.") raise ValueError("The 'predictions' must be a numpy ndarray.")
kepsilon = 1e-7 # to account for floating point imprecisions kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) thresholds = [(i + 1) * 1.0 / (self._num_thresholds - 1)
for i in range(num_thresholds - 2)] for i in range(self._num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# calculate TP, FN, TN, FP count # calculate TP, FN, TN, FP count
...@@ -358,19 +359,20 @@ class Auc(MetricBase): ...@@ -358,19 +359,20 @@ class Auc(MetricBase):
fp += 1 fp += 1
else: else:
tn += 1 tn += 1
tp_list[idx_thresh] += tp self.tp_list[idx_thresh] += tp
fn_list[idx_thresh] += fn self.fn_list[idx_thresh] += fn
tn_list[idx_thresh] += tn self.tn_list[idx_thresh] += tn
fp_list[idx_thresh] += fp self.fp_list[idx_thresh] += fp
def eval(self): def eval(self):
epsilon = self._epsilon epsilon = self._epsilon
num_thresholds = self._num_thresholds num_thresholds = self._num_thresholds
tpr = (tp_list.astype("float32") + epsilon) / ( tpr = (self.tp_list.astype("float32") + epsilon) / (
tp_list + fn_list + epsilon) self.tp_list + self.fn_list + epsilon)
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon) fpr = self.fp_list.astype("float32") / (
rec = (tp_list.astype("float32") + epsilon) / ( self.fp_list + self.tn_list + epsilon)
tp_list + fp_list + epsilon) rec = (self.tp_list.astype("float32") + epsilon) / (
self.tp_list + self.fp_list + epsilon)
x = fpr[:num_thresholds - 1] - fpr[1:] x = fpr[:num_thresholds - 1] - fpr[1:]
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
......
...@@ -43,7 +43,7 @@ class ParallelExecutor(object): ...@@ -43,7 +43,7 @@ class ParallelExecutor(object):
training. training.
allow_op_delay(bool, default False): Whether to delay and buffer allow_op_delay(bool, default False): Whether to delay and buffer
some operators together for scheduling or not, which may some operators together for scheduling or not, which may
improve performance in some cases, defalut False. improve performance in some cases, default False.
share_vars_from(ParallelExecutor, default None): If provided, share_vars_from(ParallelExecutor, default None): If provided,
it will share variables from the specified ParallelExecutor. it will share variables from the specified ParallelExecutor.
use_default_grad_scale(bool, default True): If set True, a default use_default_grad_scale(bool, default True): If set True, a default
...@@ -93,7 +93,7 @@ class ParallelExecutor(object): ...@@ -93,7 +93,7 @@ class ParallelExecutor(object):
if use_cuda: if use_cuda:
# Experiments on se-resnext show that too many threads hurt # Experiments on se-resnext show that too many threads hurt
# performance. Worth tuning for other models in the future. # performance. Worth tuning for other models in the future.
num_threads = len(self._places) num_threads = len(self._places) * 2
else: else:
num_threads = min( num_threads = min(
len(self._places) * 2, multiprocessing.cpu_count()) len(self._places) * 2, multiprocessing.cpu_count())
...@@ -130,6 +130,7 @@ class ParallelExecutor(object): ...@@ -130,6 +130,7 @@ class ParallelExecutor(object):
local_scopes, local_scopes,
allow_op_delay, allow_op_delay,
use_default_grad_scale) use_default_grad_scale)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None):
......
...@@ -70,9 +70,11 @@ def conv3d_forward_naive(input, filter, group, conv_param): ...@@ -70,9 +70,11 @@ def conv3d_forward_naive(input, filter, group, conv_param):
class TestConv3dOp(OpTest): class TestConv3dOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "conv3d"
self.use_cudnn = False self.use_cudnn = False
self.dtype = np.float32
self.init_kernel_type()
self.init_group() self.init_group()
self.init_op_type()
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
...@@ -80,20 +82,24 @@ class TestConv3dOp(OpTest): ...@@ -80,20 +82,24 @@ class TestConv3dOp(OpTest):
'stride': self.stride, 'stride': self.stride,
'pad': self.pad, 'pad': self.pad,
'dilations': self.dilations, 'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later 'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
} }
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32") input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(input, filter, self.groups, output = conv3d_forward_naive(input, filter, self.groups,
conv3d_param).astype("float32") conv3d_param).astype(self.dtype)
self.inputs = {'Input': input, 'Filter': filter} self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = { self.attrs = {
'strides': self.stride, 'strides': self.stride,
'paddings': self.pad, 'paddings': self.pad,
'groups': self.groups, 'groups': self.groups,
'dilations': self.dilations 'dilations': self.dilations,
'use_cudnn': self.use_cudnn
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
...@@ -108,6 +114,8 @@ class TestConv3dOp(OpTest): ...@@ -108,6 +114,8 @@ class TestConv3dOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16:
return
if self.testcudnn(): if self.testcudnn():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
...@@ -120,6 +128,8 @@ class TestConv3dOp(OpTest): ...@@ -120,6 +128,8 @@ class TestConv3dOp(OpTest):
set(['Input', 'Filter']), 'Output', max_relative_error=0.03) set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
if self.testcudnn(): if self.testcudnn():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
...@@ -135,6 +145,8 @@ class TestConv3dOp(OpTest): ...@@ -135,6 +145,8 @@ class TestConv3dOp(OpTest):
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
if self.testcudnn(): if self.testcudnn():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
...@@ -163,8 +175,8 @@ class TestConv3dOp(OpTest): ...@@ -163,8 +175,8 @@ class TestConv3dOp(OpTest):
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
def init_op_type(self): def init_kernel_type(self):
self.op_type = "conv3d" pass
class TestCase1(TestConv3dOp): class TestCase1(TestConv3dOp):
...@@ -235,34 +247,90 @@ class TestWithDilation(TestConv3dOp): ...@@ -235,34 +247,90 @@ class TestWithDilation(TestConv3dOp):
self.groups = 3 self.groups = 3
#----------------Conv3dCUDNN----------------
class TestCUDNN(TestConv3dOp): class TestCUDNN(TestConv3dOp):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d"
class TestFP16CUDNN(TestConv3dOp):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
class TestWithGroup1CUDNN(TestWithGroup1): class TestWithGroup1CUDNN(TestWithGroup1):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d"
class TestFP16WithGroup1CUDNN(TestWithGroup1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
class TestWithGroup2CUDNN(TestWithGroup2): class TestWithGroup2CUDNN(TestWithGroup2):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d"
class TestFP16WithGroup2CUDNN(TestWithGroup2):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
class TestWith1x1CUDNN(TestWith1x1): class TestWith1x1CUDNN(TestWith1x1):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d"
class TestFP16With1x1CUDNN(TestWith1x1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1): class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "conv3d"
class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
# FIXME(typhoonzero): find a way to determine if # FIXME(typhoonzero): find a way to determine if
......
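The FP16 CUDNN cases added above all follow one pattern: the subclass overrides init_kernel_type() to switch on cuDNN and set dtype to float16, overrides test_check_output() to run only on a CUDA place that supports float16 (with a relaxed atol), and relies on the inherited gradient tests returning early for float16. A minimal, framework-free sketch of that pattern (class and helper names below are illustrative, not from this patch):

import unittest
import numpy as np


def cuda_available():
    # Stand-in for core.is_compiled_with_cuda()/core.is_float16_supported();
    # always False here so the sketch runs anywhere.
    return False


class BaseOpCase(unittest.TestCase):
    def setUp(self):
        self.use_cudnn = False
        self.dtype = np.float32
        self.init_kernel_type()
        x = np.random.random((2, 3)).astype(self.dtype)
        self.inputs = {'X': x}
        self.expected = x * 2.0  # toy "operator": scale by 2

    def init_kernel_type(self):
        # Base class keeps the defaults; subclasses override this hook.
        pass

    def test_check_output(self):
        out = self.inputs['X'] * 2.0
        np.testing.assert_allclose(out, self.expected, atol=1e-5)

    def test_check_grad(self):
        if self.dtype == np.float16:
            return  # FP16 kernels are forward-only in these tests
        # ... gradient checking would go here ...


class FP16CUDNNCase(BaseOpCase):
    def init_kernel_type(self):
        self.use_cudnn = True
        self.dtype = np.float16

    def test_check_output(self):
        if not cuda_available():
            return  # mirrors the guarded check_output_with_place calls
        out = self.inputs['X'].astype(np.float32) * 2.0
        np.testing.assert_allclose(out, self.expected, atol=2e-2)


if __name__ == '__main__':
    unittest.main()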
...@@ -280,7 +280,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -280,7 +280,7 @@ class TestMNIST(TestParallelExecutorBase):
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) './mnist.recordio', reader, feeder)
def test_simple_fc(self): def check_simple_fc_convergence(self):
self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True) self.check_network_convergence(simple_fc_net, allow_op_delay=True)
...@@ -290,7 +290,10 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -290,7 +290,10 @@ class TestMNIST(TestParallelExecutorBase):
simple_fc_net, feed_dict={"image": img, simple_fc_net, feed_dict={"image": img,
"label": label}) "label": label})
def test_simple_fc_parallel_accuracy(self): def test_simple_fc(self):
self.check_simple_fc_convergence()
def check_simple_fc_parallel_accuracy(self):
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64') label = numpy.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
...@@ -311,7 +314,10 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -311,7 +314,10 @@ class TestMNIST(TestParallelExecutorBase):
for p_l in parallel_last_loss: for p_l in parallel_last_loss:
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6) self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_batchnorm_fc(self): def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy()
def check_batchnorm_fc_convergence(self):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm)
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64') label = numpy.ones(shape=[32, 1], dtype='int64')
...@@ -319,6 +325,9 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -319,6 +325,9 @@ class TestMNIST(TestParallelExecutorBase):
fc_with_batchnorm, feed_dict={"image": img, fc_with_batchnorm, feed_dict={"image": img,
"label": label}) "label": label})
def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence()
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
# @classmethod # @classmethod
...@@ -339,7 +348,7 @@ class TestResnet(TestParallelExecutorBase): ...@@ -339,7 +348,7 @@ class TestResnet(TestParallelExecutorBase):
# fluid.recordio_writer.convert_reader_to_recordio_file( # fluid.recordio_writer.convert_reader_to_recordio_file(
# "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress) # "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)
def test_resnet(self): def check_resnet_convergence(self):
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
...@@ -348,6 +357,9 @@ class TestResnet(TestParallelExecutorBase): ...@@ -348,6 +357,9 @@ class TestResnet(TestParallelExecutorBase):
iter=20, iter=20,
batch_size=batch_size) batch_size=batch_size)
def test_resnet(self):
self.check_resnet_convergence()
class ModelHyperParams(object): class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses # Dictionary size for source and target language. This model directly uses
...@@ -510,7 +522,7 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -510,7 +522,7 @@ class TestTransformer(TestParallelExecutorBase):
class ParallelExecutorTestingDuringTraining(unittest.TestCase): class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing(self): def check_network_convergence(self):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -550,6 +562,9 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -550,6 +562,9 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
"Train loss: " + str(train_loss) + "\n Test loss:" + "Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss)) str(test_loss))
def test_parallel(self):
self.check_network_convergence()
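The renamings in the parallel-executor tests (test_simple_fc, test_batchnorm_fc, test_resnet, test_parallel_testing becoming check_* helpers with thin test_* wrappers) keep unittest discovery intact while letting the body be reused or parameterized later. A stripped-down sketch of the shape of that refactor:

import unittest


class TestMNISTLike(unittest.TestCase):
    # unittest only auto-runs methods whose names start with "test_",
    # so moving the body into a check_* helper lets it be shared or
    # parameterized, while the thin test_* wrapper keeps the case
    # visible to the test runner.
    def check_simple_fc_convergence(self):
        first_loss, last_loss = 10.0, 1.0   # placeholder "training"
        self.assertLess(last_loss, first_loss)

    def test_simple_fc(self):
        self.check_simple_fc_convergence()


if __name__ == '__main__':
    unittest.main()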
import paddle.dataset.conll05 as conll05 import paddle.dataset.conll05 as conll05
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -568,21 +583,26 @@ embedding_name = 'emb' ...@@ -568,21 +583,26 @@ embedding_name = 'emb'
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
**ignored): is_sparse, **ignored):
# 8 features # 8 features
predicate_embedding = fluid.layers.embedding( predicate_embedding = fluid.layers.embedding(
input=predicate, input=predicate,
is_sparse=is_sparse,
size=[pred_dict_len, word_dim], size=[pred_dict_len, word_dim],
dtype='float32', dtype='float32',
param_attr='vemb') param_attr='vemb')
mark_embedding = fluid.layers.embedding( mark_embedding = fluid.layers.embedding(
input=mark, size=[mark_dict_len, mark_dim], dtype='float32') input=mark,
is_sparse=is_sparse,
size=[mark_dict_len, mark_dim],
dtype='float32')
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2] word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
emb_layers = [ emb_layers = [
fluid.layers.embedding( fluid.layers.embedding(
size=[word_dict_len, word_dim], size=[word_dict_len, word_dim],
is_sparse=is_sparse,
input=x, input=x,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=embedding_name, trainable=False)) for x in word_input name=embedding_name, trainable=False)) for x in word_input
...@@ -632,7 +652,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, ...@@ -632,7 +652,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
class TestCRFModel(unittest.TestCase): class TestCRFModel(unittest.TestCase):
def test_all(self): def check_network_convergence(self, is_sparse):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -652,6 +672,7 @@ class TestCRFModel(unittest.TestCase): ...@@ -652,6 +672,7 @@ class TestCRFModel(unittest.TestCase):
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1) name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
mark = fluid.layers.data( mark = fluid.layers.data(
name='mark_data', shape=[1], dtype='int64', lod_level=1) name='mark_data', shape=[1], dtype='int64', lod_level=1)
feature_out = db_lstm(**locals()) feature_out = db_lstm(**locals())
target = fluid.layers.data( target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1) name='target', shape=[1], dtype='int64', lod_level=1)
...@@ -694,3 +715,9 @@ class TestCRFModel(unittest.TestCase): ...@@ -694,3 +715,9 @@ class TestCRFModel(unittest.TestCase):
print map(numpy.array, print map(numpy.array,
pe.run(feed=feeder.feed(cur_batch), pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))[0] fetch_list=[avg_cost.name]))[0]
def test_update_sparse_parameter(self):
self.check_network_convergence(is_sparse=True)
def test_update_dense_parameter(self):
self.check_network_convergence(is_sparse=False)
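In the CRF test the new is_sparse argument reaches db_lstm through feature_out = db_lstm(**locals()): every local name of check_network_convergence is forwarded as a keyword argument, and db_lstm's **ignored catch-all absorbs the ones it does not name. A standalone illustration of that forwarding style (all names below are made up):

def db_lstm_like(word, mark, is_sparse, **ignored):
    # Only the arguments it names are used; everything else lands in
    # **ignored, which is how the real db_lstm tolerates extra locals.
    return {'word': word, 'mark': mark, 'sparse_update': is_sparse}


def check_network_convergence(is_sparse):
    word = 'word_data'
    mark = 'mark_data'
    target = 'target_data'          # extra local, silently ignored below
    feature_out = db_lstm_like(**locals())
    assert feature_out['sparse_update'] is is_sparse
    return feature_out


if __name__ == '__main__':
    print(check_network_convergence(is_sparse=True))
    print(check_network_convergence(is_sparse=False))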
...@@ -90,20 +90,22 @@ def avg_pool3D_forward_naive(x, ...@@ -90,20 +90,22 @@ def avg_pool3D_forward_naive(x,
class TestPool3d_Op(OpTest): class TestPool3d_Op(OpTest):
def setUp(self): def setUp(self):
self.op_type = "pool3d"
self.use_cudnn = False self.use_cudnn = False
self.dtype = np.float32
self.init_test_case() self.init_test_case()
self.init_global_pool() self.init_global_pool()
self.init_op_type() self.init_kernel_type()
self.init_pool_type() self.init_pool_type()
self.init_ceil_mode() self.init_ceil_mode()
if self.global_pool: if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))] self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype(self.dtype)
output = self.pool3D_forward_naive(input, self.ksize, self.strides, output = self.pool3D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool, self.paddings, self.global_pool,
self.ceil_mode).astype("float32") self.ceil_mode).astype(self.dtype)
self.inputs = {'X': input} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = { self.attrs = {
'strides': self.strides, 'strides': self.strides,
...@@ -116,7 +118,7 @@ class TestPool3d_Op(OpTest): ...@@ -116,7 +118,7 @@ class TestPool3d_Op(OpTest):
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
} }
self.outputs = {'Out': output.astype('float32')} self.outputs = {'Out': output}
def testcudnn(self): def testcudnn(self):
return core.is_compiled_with_cuda() and self.use_cudnn return core.is_compiled_with_cuda() and self.use_cudnn
...@@ -129,6 +131,8 @@ class TestPool3d_Op(OpTest): ...@@ -129,6 +131,8 @@ class TestPool3d_Op(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16:
return
if self.testcudnn() and self.pool_type != "max": if self.testcudnn() and self.pool_type != "max":
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
...@@ -142,8 +146,8 @@ class TestPool3d_Op(OpTest): ...@@ -142,8 +146,8 @@ class TestPool3d_Op(OpTest):
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.paddings = [0, 0, 0] self.paddings = [0, 0, 0]
def init_op_type(self): def init_kernel_type(self):
self.op_type = "pool3d" pass
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "avg" self.pool_type = "avg"
...@@ -158,15 +162,11 @@ class TestPool3d_Op(OpTest): ...@@ -158,15 +162,11 @@ class TestPool3d_Op(OpTest):
class TestCase1(TestPool3d_Op): class TestCase1(TestPool3d_Op):
def init_test_case(self): def init_test_case(self):
self.op_type = "pool3d"
self.shape = [2, 3, 7, 7, 7] self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3] self.ksize = [3, 3, 3]
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.paddings = [0, 0, 0] self.paddings = [0, 0, 0]
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "avg" self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive self.pool3D_forward_naive = avg_pool3D_forward_naive
...@@ -182,9 +182,6 @@ class TestCase2(TestPool3d_Op): ...@@ -182,9 +182,6 @@ class TestCase2(TestPool3d_Op):
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.paddings = [1, 1, 1] self.paddings = [1, 1, 1]
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "avg" self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive self.pool3D_forward_naive = avg_pool3D_forward_naive
...@@ -194,27 +191,18 @@ class TestCase2(TestPool3d_Op): ...@@ -194,27 +191,18 @@ class TestCase2(TestPool3d_Op):
class TestCase3(TestPool3d_Op): class TestCase3(TestPool3d_Op):
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "max" self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive self.pool3D_forward_naive = max_pool3D_forward_naive
class TestCase4(TestCase1): class TestCase4(TestCase1):
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "max" self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive self.pool3D_forward_naive = max_pool3D_forward_naive
class TestCase5(TestCase2): class TestCase5(TestCase2):
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self): def init_pool_type(self):
self.pool_type = "max" self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive self.pool3D_forward_naive = max_pool3D_forward_naive
...@@ -222,39 +210,105 @@ class TestCase5(TestCase2): ...@@ -222,39 +210,105 @@ class TestCase5(TestCase2):
#--------------------test pool3d-------------------- #--------------------test pool3d--------------------
class TestCUDNNCase1(TestPool3d_Op): class TestCUDNNCase1(TestPool3d_Op):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase1(TestPool3d_Op):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCUDNNCase2(TestCase1): class TestCUDNNCase2(TestCase1):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase2(TestCase1):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCUDNNCase3(TestCase2): class TestCUDNNCase3(TestCase2):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase3(TestCase2):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCUDNNCase4(TestCase3): class TestCUDNNCase4(TestCase3):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase4(TestCase3):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCUDNNCase5(TestCase4): class TestCUDNNCase5(TestCase4):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase5(TestCase4):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCUDNNCase6(TestCase5): class TestCUDNNCase6(TestCase5):
def init_op_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
self.op_type = "pool3d"
class TestFP16CUDNNCase6(TestCase5):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCeilModeCase1(TestCUDNNCase1): class TestCeilModeCase1(TestCUDNNCase1):
......
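The relaxed tolerances in these checks (atol=1e-3 for pool3d, 2e-2 for conv3d) are presumably tied to float16's resolution: its 10-bit mantissa gives a unit roundoff of roughly 1e-3 per value, and a convolution accumulates far more products than a single pooling window, so its absolute error grows. A quick NumPy comparison (not part of the patch) illustrates the gap:

import numpy as np

np.random.seed(0)
x = np.random.random((3, 3, 3)).astype(np.float16)   # one 3x3x3 pooling window
w = np.random.random((3, 3, 3)).astype(np.float16)

# Average pooling touches each value once; error stays near fp16 resolution.
avg_fp16 = x.mean(dtype=np.float16)
avg_fp32 = x.astype(np.float32).mean()
print('pool diff:', abs(float(avg_fp16) - float(avg_fp32)))

# A convolution-style dot product accumulates 27 products here (and far more
# in the real tests), so the absolute error grows accordingly.
dot_fp16 = np.sum(x * w, dtype=np.float16)
dot_fp32 = np.sum(x.astype(np.float32) * w.astype(np.float32))
print('conv-style diff:', abs(float(dot_fp16) - float(dot_fp32)))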
...@@ -12,14 +12,17 @@ ...@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import core import core
import framework import framework
import executor import executor
import data_feeder import data_feeder
import contextlib import contextlib
import io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module import optimizer as opt_module
from transpiler import distribute_transpiler
__all__ = [ __all__ = [
'Trainer', 'Trainer',
...@@ -76,22 +79,60 @@ class Trainer(object): ...@@ -76,22 +79,60 @@ class Trainer(object):
raise TypeError( raise TypeError(
"The optimizer should be an instance of Optimizer") "The optimizer should be an instance of Optimizer")
optimizer.minimize(loss) optimize_ops, params_grads = optimizer.minimize(loss)
self.place = Trainer._check_and_get_place(place) self.place = Trainer._check_and_get_place(place)
self.dist_transpile_if_necessary(optimize_ops, params_grads)
# 2. move the default_main_program to self.program and run the # 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope() # default_startup program on an empty core.Scope()
# Run startup program # Run startup program
exe = executor.Executor(place) with self._prog_and_scope_guard():
exe.run(self.startup_program, scope=self.scope) exe = executor.Executor(place)
exe.run(self.startup_program)
if param_path: if param_path:
# load params from param_path into scope # load params from param_path into scope
# TODO(yuyang): This depends on parameters implementation. io.load_persistables(exe, dirname=param_path)
pass
def dist_transpile_if_necessary(self, optimize_ops, params_grads):
# TODO(helin): support distributed training if "PADDLE_TRAINING_ROLE" not in os.environ:
return
# the port of all pservers, needed by both trainer and pserver
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers = int(os.getenv("PADDLE_TRAINERS"))
# the IP of the local machine, needed by pserver only
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer
# only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
with self._prog_and_scope_guard():
t = distribute_transpiler.DistributeTranspiler()
t.transpile(
trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint,
self.train_program)
elif training_role == "TRAINER":
self.train_program = t.get_trainer_program()
else:
raise ValueError(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def train(self, def train(self,
num_epochs, num_epochs,
...@@ -117,6 +158,13 @@ class Trainer(object): ...@@ -117,6 +158,13 @@ class Trainer(object):
raise NotImplementedError( raise NotImplementedError(
"Parallel Executor version of trainer is not implemented") "Parallel Executor version of trainer is not implemented")
training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
if training_role == "PSERVER":
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
exe.run()
return
self._train_by_executor(num_epochs, event_handler, reader, feed_order) self._train_by_executor(num_epochs, event_handler, reader, feed_order)
def test(self, reader): def test(self, reader):
...@@ -124,7 +172,9 @@ class Trainer(object): ...@@ -124,7 +172,9 @@ class Trainer(object):
def save_params(self, param_path): def save_params(self, param_path):
# reference: save_persistables in io.py # reference: save_persistables in io.py
pass exe = executor.Executor(self.place)
io.save_persistables(
exe, dirname=param_path, main_program=self.startup_program)
@staticmethod @staticmethod
def _check_and_get_place(place): def _check_and_get_place(place):
......
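The new dist_transpile_if_necessary path is driven entirely by environment variables, so the same script can come up as a parameter server or a trainer depending on how it is launched. A sketch of how those variables might be set for a small local run (addresses and counts are illustrative; Trainer construction itself is elided):

import os

# Parameter-server side: needs the pserver port, its own IP, and the
# total trainer count.
os.environ['PADDLE_TRAINING_ROLE'] = 'PSERVER'
os.environ['PADDLE_PSERVER_PORT'] = '6174'
os.environ['PADDLE_PSERVER_IPS'] = '127.0.0.1'
os.environ['PADDLE_CURRENT_IP'] = '127.0.0.1'
os.environ['PADDLE_TRAINERS'] = '2'

# A trainer process would instead set:
#   PADDLE_TRAINING_ROLE=TRAINER
#   PADDLE_TRAINER_ID=0        (or 1, ..., PADDLE_TRAINERS - 1)
# plus the same PADDLE_PSERVER_IPS / PADDLE_PSERVER_PORT / PADDLE_TRAINERS.
#
# With these set, constructing the Trainer transpiles the program and,
# in train(), a PSERVER process simply runs its listen-and-serve program.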
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
__all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory"
]
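With the transpilers gathered into their own package, callers would import them from paddle.fluid.transpiler rather than from the old top-level fluid modules. A hedged usage sketch (assuming an installed paddle.fluid of this vintage; the transpile arguments mirror the calls visible elsewhere in this diff):

# Assumes paddle.fluid exposes the transpiler package shown above; the
# exact re-export path may differ by release.
from paddle.fluid.transpiler import DistributeTranspiler, memory_optimize

t = DistributeTranspiler()
# t.transpile(trainer_id, pservers="127.0.0.1:6174", trainers=2)
# memory_optimize(main_program)  # optional memory-reuse pass on the program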
...@@ -17,9 +17,8 @@ from __future__ import print_function ...@@ -17,9 +17,8 @@ from __future__ import print_function
import math import math
import distributed_splitter as splitter import distributed_splitter as splitter
import framework from .. import core
from framework import Program, default_main_program, Variable, Parameter from ..framework import Program, default_main_program, Variable, Parameter
from . import core
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
...@@ -135,6 +134,16 @@ def split_dense_variable(var_list, ...@@ -135,6 +134,16 @@ def split_dense_variable(var_list,
return blocks return blocks
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
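The module-level delete_ops helper removes a contiguous run of operators by calling block.remove_op(start) repeatedly: each removal shifts the remaining ops left, so the same index walks through the whole run. A plain-list analogue of that loop:

ops = ['read', 'mul', 'elementwise_add', 'mean', 'sgd']
to_delete = ['mul', 'elementwise_add', 'mean']      # contiguous run

start = ops.index(to_delete[0])
end = ops.index(to_delete[-1])
for _ in range(end - start + 1):
    # Each pop shifts the tail left, so the same index keeps pointing
    # at the next op of the run -- exactly what repeated remove_op(start) does.
    ops.pop(start)

print(ops)   # ['read', 'sgd']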
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def transpile(self,
trainer_id, trainer_id,
...@@ -317,7 +326,7 @@ class DistributeTranspiler: ...@@ -317,7 +326,7 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone. # FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__() self.origin_program.__str__()
return self.origin_program return self.origin_program
...@@ -601,7 +610,7 @@ class DistributeTranspiler: ...@@ -601,7 +610,7 @@ class DistributeTranspiler:
attrs={"axis": 0}) attrs={"axis": 0})
# delete lookup_table_op # delete lookup_table_op
self.delete_ops(program.global_block(), [op]) delete_ops(program.global_block(), [op])
# break for loop # break for loop
break break
...@@ -1164,12 +1173,3 @@ class DistributeTranspiler: ...@@ -1164,12 +1173,3 @@ class DistributeTranspiler:
in_name.startswith("beta2_pow_acc"): in_name.startswith("beta2_pow_acc"):
return True return True
return False return False
def delete_ops(self, block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
...@@ -12,10 +12,8 @@ ...@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import framework from ..framework import Program, default_main_program, Parameter, Variable
from framework import Program, default_main_program, Parameter, Variable from ..layer_helper import LayerHelper
import optimizer
from layer_helper import LayerHelper
def hash_name_to_server(params_grads, pserver_endpoints): def hash_name_to_server(params_grads, pserver_endpoints):
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from framework import Program from .. import core
from executor import global_scope from ..framework import Program
from . import core from ..executor import global_scope
class InferenceTranspiler: class InferenceTranspiler:
......
...@@ -13,11 +13,9 @@ ...@@ -13,11 +13,9 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
import framework from .. import core
from framework import Program, default_main_program, Parameter, Variable from ..framework import Program, default_main_program, Parameter, Variable
import backward from ..backward import _rename_arg_
from backward import _rename_arg_
from . import core
dtype_to_size = { dtype_to_size = {
core.VarDesc.VarType.FP16: 2, core.VarDesc.VarType.FP16: 2,
......
...@@ -22,7 +22,11 @@ import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2 ...@@ -22,7 +22,11 @@ import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
'--profile_path', type=str, default='', help='Input profile file name.') '--profile_path',
type=str,
default='',
help='Input profile file name. If there are multiple file, the format '
'should be trainer1=file1,trainer2=file2,ps=file3')
parser.add_argument( parser.add_argument(
'--timeline_path', type=str, default='', help='Output timeline file name.') '--timeline_path', type=str, default='', help='Output timeline file name.')
args = parser.parse_args() args = parser.parse_args()
...@@ -108,8 +112,8 @@ class _ChromeTraceFormatter(object): ...@@ -108,8 +112,8 @@ class _ChromeTraceFormatter(object):
class Timeline(object): class Timeline(object):
def __init__(self, profile_pb): def __init__(self, profile_dict):
self._profile_pb = profile_pb self._profile_dict = profile_dict
self._pid = 0 self._pid = 0
self._devices = dict() self._devices = dict()
self._chrome_trace = _ChromeTraceFormatter() self._chrome_trace = _ChromeTraceFormatter()
...@@ -120,35 +124,37 @@ class Timeline(object): ...@@ -120,35 +124,37 @@ class Timeline(object):
return cur_pid return cur_pid
def _allocate_pids(self): def _allocate_pids(self):
for event in self._profile_pb.events: for k, profile_pb in self._profile_dict.iteritems():
if event.type == profiler_pb2.Event.CPU: for event in profile_pb.events:
if (event.device_id, "CPU") not in self._devices: if event.type == profiler_pb2.Event.CPU:
pid = self._allocate_pid() if (k, event.device_id, "CPU") not in self._devices:
self._devices[(event.device_id, "CPU")] = pid pid = self._allocate_pid()
self._chrome_trace.emit_pid("cpu:block:%d" % self._devices[(k, event.device_id, "CPU")] = pid
(event.device_id), pid) self._chrome_trace.emit_pid("%s:cpu:block:%d" %
elif event.type == profiler_pb2.Event.GPUKernel: (k, event.device_id), pid)
if (event.device_id, "GPUKernel") not in self._devices: elif event.type == profiler_pb2.Event.GPUKernel:
pid = self._allocate_pid() if (k, event.device_id, "GPUKernel") not in self._devices:
self._devices[(event.device_id, "GPUKernel")] = pid pid = self._allocate_pid()
self._chrome_trace.emit_pid("gpu:%d" % (event.device_id), self._devices[(k, event.device_id, "GPUKernel")] = pid
pid) self._chrome_trace.emit_pid("%s:gpu:%d" %
(k, event.device_id), pid)
def _allocate_events(self): def _allocate_events(self):
for event in self._profile_pb.events: for k, profile_pb in self._profile_dict.iteritems():
if event.type == profiler_pb2.Event.CPU: for event in profile_pb.events:
type = "CPU" if event.type == profiler_pb2.Event.CPU:
elif event.type == profiler_pb2.Event.GPUKernel: type = "CPU"
type = "GPUKernel" elif event.type == profiler_pb2.Event.GPUKernel:
pid = self._devices[(event.device_id, type)] type = "GPUKernel"
args = {'name': event.name} pid = self._devices[(k, event.device_id, type)]
if event.memcopy.bytes > 0: args = {'name': event.name}
args = {'mem_bytes': event.memcopy.bytes} if event.memcopy.bytes > 0:
# TODO(panyx0718): Chrome tracing only handles ms. However, some args = {'mem_bytes': event.memcopy.bytes}
# ops takes micro-seconds. Hence, we keep the ns here. # TODO(panyx0718): Chrome tracing only handles ms. However, some
self._chrome_trace.emit_region( # ops takes micro-seconds. Hence, we keep the ns here.
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid, self._chrome_trace.emit_region(
event.sub_device_id, 'Op', event.name, args) event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
event.sub_device_id, 'Op', event.name, args)
def generate_chrome_trace(self): def generate_chrome_trace(self):
self._allocate_pids() self._allocate_pids()
...@@ -163,11 +169,23 @@ timeline_path = '/tmp/timeline' ...@@ -163,11 +169,23 @@ timeline_path = '/tmp/timeline'
if args.timeline_path: if args.timeline_path:
timeline_path = args.timeline_path timeline_path = args.timeline_path
with open(profile_path, 'r') as f: profile_paths = profile_path.split(',')
profile_s = f.read() profile_dict = dict()
profile_pb = profiler_pb2.Profile() if len(profile_path) == 1:
profile_pb.ParseFromString(profile_s) with open(profile_path, 'r') as f:
profile_s = f.read()
tl = Timeline(profile_pb) profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict['trainer'] = profile_pb
else:
for profile_path in profile_paths:
k, v = profile_path.split('=')
with open(v, 'r') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict[k] = profile_pb
tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f: with open(timeline_path, 'w') as f:
f.write(tl.generate_chrome_trace()) f.write(tl.generate_chrome_trace())
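timeline.py now accepts either a single profile file or a comma-separated list of name=file pairs (note that the committed branch tests len(profile_path) == 1, the length of the string, where the number of split paths is presumably what was meant). A standalone parser for the same --profile_path format might look like this (the helper name is made up):

def parse_profile_paths(profile_path):
    """Parse '--profile_path' into {name: filename}.

    Accepts either a single file ('/tmp/profile') or a comma-separated
    list of name=file pairs ('trainer1=f1,trainer2=f2,ps=f3').
    """
    parts = profile_path.split(',')
    if len(parts) == 1 and '=' not in parts[0]:
        return {'trainer': parts[0]}
    result = {}
    for part in parts:
        name, filename = part.split('=', 1)
        result[name] = filename
    return result


if __name__ == '__main__':
    print(parse_profile_paths('/tmp/profile_vgg'))
    print(parse_profile_paths('trainer1=/tmp/p1,ps=/tmp/ps'))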